mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-31 02:25:29 +08:00
Compare commits
14 Commits
tools/goct
...
v1.8.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
69aa7fe346 | ||
|
|
c3820a95c1 | ||
|
|
493f3bad0f | ||
|
|
eb0d5ad3a4 | ||
|
|
14192050ae | ||
|
|
9193e771e3 | ||
|
|
808b4e496a | ||
|
|
e416d01f8d | ||
|
|
789c5de873 | ||
|
|
52078a0c14 | ||
|
|
7ef13116a0 | ||
|
|
6b8053410a | ||
|
|
81c6928445 | ||
|
|
761c2dd716 |
@@ -142,7 +142,7 @@ func validateOptional(field reflect.StructField, value reflect.Value) error {
|
|||||||
if value.IsNil() {
|
if value.IsNil() {
|
||||||
return fmt.Errorf("field %q is nil", field.Name)
|
return fmt.Errorf("field %q is nil", field.Name)
|
||||||
}
|
}
|
||||||
case reflect.Array, reflect.Slice, reflect.Map:
|
case reflect.Slice, reflect.Map:
|
||||||
if value.IsNil() || value.Len() == 0 {
|
if value.IsNil() || value.Len() == 0 {
|
||||||
return fmt.Errorf("field %q is empty", field.Name)
|
return fmt.Errorf("field %q is empty", field.Name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -462,3 +462,15 @@ func TestMarshal_FromString(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10", m["json"]["age"].(string))
|
assert.Equal(t, "10", m["json"]["age"].(string))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMarshal_Array(t *testing.T) {
|
||||||
|
v := struct {
|
||||||
|
H [1]int `json:"h,string"`
|
||||||
|
}{
|
||||||
|
H: [1]int{1},
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Marshal(v)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "[1]", m["json"]["h"].(string))
|
||||||
|
}
|
||||||
|
|||||||
@@ -201,6 +201,13 @@ func TestHttpToHttp(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("method not allowed", func(t *testing.T) {
|
||||||
|
resp, err := httpc.Do(context.Background(), http.MethodPost,
|
||||||
|
"http://localhost:18882/api/ping", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHttpToHttpBadUpstream(t *testing.T) {
|
func TestHttpToHttpBadUpstream(t *testing.T) {
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -17,7 +17,7 @@ require (
|
|||||||
github.com/olekukonko/tablewriter v0.0.5
|
github.com/olekukonko/tablewriter v0.0.5
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2
|
github.com/pelletier/go-toml/v2 v2.2.2
|
||||||
github.com/prometheus/client_golang v1.21.1
|
github.com/prometheus/client_golang v1.21.1
|
||||||
github.com/redis/go-redis/v9 v9.7.3
|
github.com/redis/go-redis/v9 v9.8.0
|
||||||
github.com/spaolacci/murmur3 v1.1.0
|
github.com/spaolacci/murmur3 v1.1.0
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
go.etcd.io/etcd/api/v3 v3.5.15
|
go.etcd.io/etcd/api/v3 v3.5.15
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -159,8 +159,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
|
|||||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI=
|
||||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
|
|||||||
@@ -34,7 +34,10 @@ type McpConf struct {
|
|||||||
// Cors contains allowed CORS origins
|
// Cors contains allowed CORS origins
|
||||||
Cors []string `json:",optional"`
|
Cors []string `json:",optional"`
|
||||||
|
|
||||||
// ToolTimeout is the maximum time allowed for tool execution
|
// SseTimeout is the maximum time allowed for SSE connections
|
||||||
ToolTimeout time.Duration `json:",default=30s"`
|
SseTimeout time.Duration `json:",default=24h"`
|
||||||
|
|
||||||
|
// MessageTimeout is the maximum time allowed for request execution
|
||||||
|
MessageTimeout time.Duration `json:",default=30s"`
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ mcp:
|
|||||||
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
||||||
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
||||||
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
||||||
assert.Equal(t, 30*time.Second, c.Mcp.ToolTimeout, "Default tool timeout should be 30s")
|
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMcpConfCustomValues(t *testing.T) {
|
func TestMcpConfCustomValues(t *testing.T) {
|
||||||
@@ -43,7 +43,7 @@ func TestMcpConfCustomValues(t *testing.T) {
|
|||||||
"SseEndpoint": "/custom-sse",
|
"SseEndpoint": "/custom-sse",
|
||||||
"MessageEndpoint": "/custom-message",
|
"MessageEndpoint": "/custom-message",
|
||||||
"Cors": ["http://localhost:3000", "http://example.com"],
|
"Cors": ["http://localhost:3000", "http://example.com"],
|
||||||
"ToolTimeout": "60s"
|
"MessageTimeout": "60s"
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
@@ -59,5 +59,5 @@ func TestMcpConfCustomValues(t *testing.T) {
|
|||||||
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
||||||
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
||||||
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
||||||
assert.Equal(t, 60*time.Second, c.Mcp.ToolTimeout, "Tool timeout should be customizable")
|
assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package mcp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -59,7 +60,7 @@ func TestHTTPHandlerIntegration(t *testing.T) {
|
|||||||
conf := McpConf{}
|
conf := McpConf{}
|
||||||
conf.Mcp.Name = "test-integration"
|
conf.Mcp.Name = "test-integration"
|
||||||
conf.Mcp.Version = "1.0.0-test"
|
conf.Mcp.Version = "1.0.0-test"
|
||||||
conf.Mcp.ToolTimeout = 1 * time.Second
|
conf.Mcp.MessageTimeout = 1 * time.Second
|
||||||
|
|
||||||
// Create a mock server directly
|
// Create a mock server directly
|
||||||
server := &sseMcpServer{
|
server := &sseMcpServer{
|
||||||
@@ -75,7 +76,6 @@ func TestHTTPHandlerIntegration(t *testing.T) {
|
|||||||
Name: "echo",
|
Name: "echo",
|
||||||
Description: "Echo tool for testing",
|
Description: "Echo tool for testing",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{
|
Properties: map[string]any{
|
||||||
"message": map[string]any{
|
"message": map[string]any{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -83,7 +83,7 @@ func TestHTTPHandlerIntegration(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
if msg, ok := params["message"].(string); ok {
|
if msg, ok := params["message"].(string); ok {
|
||||||
return fmt.Sprintf("Echo: %s", msg), nil
|
return fmt.Sprintf("Echo: %s", msg), nil
|
||||||
}
|
}
|
||||||
@@ -181,7 +181,7 @@ func TestHandlerResponseFlow(t *testing.T) {
|
|||||||
Name: "test.tool",
|
Name: "test.tool",
|
||||||
Description: "Test tool",
|
Description: "Test tool",
|
||||||
InputSchema: InputSchema{Type: "object"},
|
InputSchema: InputSchema{Type: "object"},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return "tool result", nil
|
return "tool result", nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -329,7 +329,7 @@ func TestProcessListMethods(t *testing.T) {
|
|||||||
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
||||||
}
|
}
|
||||||
|
|
||||||
server.processListTools(client, req)
|
server.processListTools(context.Background(), client, req)
|
||||||
|
|
||||||
// Read response
|
// Read response
|
||||||
select {
|
select {
|
||||||
@@ -344,7 +344,7 @@ func TestProcessListMethods(t *testing.T) {
|
|||||||
req.ID = 2
|
req.ID = 2
|
||||||
req.Method = methodPromptsList
|
req.Method = methodPromptsList
|
||||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||||
server.processListPrompts(client, req)
|
server.processListPrompts(context.Background(), client, req)
|
||||||
|
|
||||||
// Read response
|
// Read response
|
||||||
select {
|
select {
|
||||||
@@ -358,7 +358,7 @@ func TestProcessListMethods(t *testing.T) {
|
|||||||
req.ID = 3
|
req.ID = 3
|
||||||
req.Method = methodResourcesList
|
req.Method = methodResourcesList
|
||||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||||
server.processListResources(client, req)
|
server.processListResources(context.Background(), client, req)
|
||||||
|
|
||||||
// Read response
|
// Read response
|
||||||
select {
|
select {
|
||||||
@@ -393,7 +393,7 @@ func TestErrorResponseHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Mock handleRequest by directly calling error handler
|
// Mock handleRequest by directly calling error handler
|
||||||
server.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
select {
|
select {
|
||||||
@@ -412,7 +412,7 @@ func TestErrorResponseHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Call process method directly
|
// Call process method directly
|
||||||
server.processToolCall(client, toolReq)
|
server.processToolCall(context.Background(), client, toolReq)
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
select {
|
select {
|
||||||
@@ -431,7 +431,7 @@ func TestErrorResponseHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Call process method directly
|
// Call process method directly
|
||||||
server.processGetPrompt(client, promptReq)
|
server.processGetPrompt(context.Background(), client, promptReq)
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
select {
|
select {
|
||||||
|
|||||||
23
mcp/parser.go
Normal file
23
mcp/parser.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/mapping"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseArguments parses the arguments and populates the request object
|
||||||
|
func ParseArguments(args any, req any) error {
|
||||||
|
switch arguments := args.(type) {
|
||||||
|
case map[string]string:
|
||||||
|
m := make(map[string]any, len(arguments))
|
||||||
|
for k, v := range arguments {
|
||||||
|
m[k] = v
|
||||||
|
}
|
||||||
|
return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues())
|
||||||
|
case map[string]any:
|
||||||
|
return mapping.UnmarshalJsonMap(arguments, req)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported argument type: %T", arguments)
|
||||||
|
}
|
||||||
|
}
|
||||||
139
mcp/parser_test.go
Normal file
139
mcp/parser_test.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestParseArguments_MapStringString tests parsing map[string]string arguments
|
||||||
|
func TestParseArguments_MapStringString(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test arguments
|
||||||
|
args := map[string]string{
|
||||||
|
"name": "test-name",
|
||||||
|
"message": "hello world",
|
||||||
|
"count": "42",
|
||||||
|
"enabled": "true",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a target object to populate
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
// Verify results
|
||||||
|
assert.NoError(t, err, "Should parse map[string]string without error")
|
||||||
|
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||||
|
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||||
|
assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int")
|
||||||
|
assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseArguments_MapStringAny tests parsing map[string]any arguments
|
||||||
|
func TestParseArguments_MapStringAny(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
Metadata map[string]string `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test arguments with mixed types
|
||||||
|
args := map[string]any{
|
||||||
|
"name": "test-name",
|
||||||
|
"message": "hello world",
|
||||||
|
"count": 42, // note: this is already an int
|
||||||
|
"enabled": true, // note: this is already a bool
|
||||||
|
"tags": []string{"tag1", "tag2"},
|
||||||
|
"metadata": map[string]string{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a target object to populate
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
// Verify results
|
||||||
|
assert.NoError(t, err, "Should parse map[string]any without error")
|
||||||
|
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||||
|
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||||
|
assert.Equal(t, 42, req.Count, "Count should be correctly parsed")
|
||||||
|
assert.True(t, req.Enabled, "Enabled should be correctly parsed")
|
||||||
|
assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed")
|
||||||
|
assert.Equal(t, map[string]string{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
}, req.Metadata, "Metadata should be correctly parsed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseArguments_UnsupportedType tests parsing with an unsupported type
|
||||||
|
func TestParseArguments_UnsupportedType(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use an unsupported argument type (slice)
|
||||||
|
args := []string{"not", "a", "map"}
|
||||||
|
|
||||||
|
// Create a target object to populate
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
// Verify error is returned with correct message
|
||||||
|
assert.Error(t, err, "Should return error for unsupported type")
|
||||||
|
assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type")
|
||||||
|
assert.Contains(t, err.Error(), "[]string", "Error should include the actual type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseArguments_EmptyMap tests parsing with empty maps
|
||||||
|
func TestParseArguments_EmptyMap(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name,optional"`
|
||||||
|
Message string `json:"message,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test empty map[string]string
|
||||||
|
t.Run("EmptyMapStringString", func(t *testing.T) {
|
||||||
|
args := map[string]string{}
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
assert.NoError(t, err, "Should parse empty map[string]string without error")
|
||||||
|
assert.Empty(t, req.Name, "Name should be empty string")
|
||||||
|
assert.Empty(t, req.Message, "Message should be empty string")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test empty map[string]any
|
||||||
|
t.Run("EmptyMapStringAny", func(t *testing.T) {
|
||||||
|
args := map[string]any{}
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
assert.NoError(t, err, "Should parse empty map[string]any without error")
|
||||||
|
assert.Empty(t, req.Name, "Name should be empty string")
|
||||||
|
assert.Empty(t, req.Message, "Message should be empty string")
|
||||||
|
})
|
||||||
|
}
|
||||||
824
mcp/readme.md
824
mcp/readme.md
@@ -1,7 +1,7 @@
|
|||||||
# Model Context Protocol (MCP) SDK Implementation
|
# Model Context Protocol (MCP) Implementation
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
This package implements a Model Context Protocol (MCP) server in Go that facilitates real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation provides a framework for building AI-assisted applications with bidirectional communication capabilities.
|
This package implements the Model Context Protocol (MCP) server specification in Go, providing a framework for real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation follows the standardized protocol for building AI-assisted applications with bidirectional communication capabilities.
|
||||||
|
|
||||||
## Core Components
|
## Core Components
|
||||||
|
|
||||||
@@ -54,9 +54,817 @@ This package implements a Model Context Protocol (MCP) server in Go that facilit
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
To create and use an MCP server, see the examples directory for practical implementation examples including:
|
### Setting Up an MCP Server
|
||||||
- Tool registration and execution
|
|
||||||
- Static and dynamic prompt creation
|
To create and start an MCP server:
|
||||||
- Resource handling with proper URI identification
|
|
||||||
- Embedded resources in prompt responses
|
```go
|
||||||
- Client connection management
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load configuration from YAML file
|
||||||
|
var c mcp.McpConf
|
||||||
|
conf.MustLoad("config.yaml", &c)
|
||||||
|
|
||||||
|
// Optional: Disable stats logging
|
||||||
|
logx.DisableStat()
|
||||||
|
|
||||||
|
// Create MCP server
|
||||||
|
server := mcp.NewMcpServer(c)
|
||||||
|
|
||||||
|
// Register tools, prompts, and resources (examples below)
|
||||||
|
|
||||||
|
// Start the server and ensure it's stopped on exit
|
||||||
|
defer server.Stop()
|
||||||
|
server.Start()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Sample configuration file (config.yaml):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
name: mcp-server
|
||||||
|
host: localhost
|
||||||
|
port: 8080
|
||||||
|
mcp:
|
||||||
|
name: my-mcp-server
|
||||||
|
messageTimeout: 30s # Timeout for tool calls
|
||||||
|
cors:
|
||||||
|
- http://localhost:3000 # Optional CORS configuration
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering Tools
|
||||||
|
|
||||||
|
Tools allow AI models to execute custom code through the MCP protocol.
|
||||||
|
|
||||||
|
#### Basic Tool Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a simple echo tool
|
||||||
|
echoTool := mcp.Tool{
|
||||||
|
Name: "echo",
|
||||||
|
Description: "Echoes back the message provided by the user",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The message to echo back",
|
||||||
|
},
|
||||||
|
"prefix": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional prefix to add to the echoed message",
|
||||||
|
"default": "Echo: ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"message"},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Prefix string `json:"prefix,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := "Echo: "
|
||||||
|
if len(req.Prefix) > 0 {
|
||||||
|
prefix = req.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefix + req.Message, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(echoTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Tool with Different Response Types:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Tool returning JSON data
|
||||||
|
dataTool := mcp.Tool{
|
||||||
|
Name: "data.generate",
|
||||||
|
Description: "Generates sample data in various formats",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"format": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Format of data (json, text)",
|
||||||
|
"enum": []string{"json", "text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
Format string `json:"format"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Format == "json" {
|
||||||
|
// Return structured data
|
||||||
|
return map[string]any{
|
||||||
|
"items": []map[string]any{
|
||||||
|
{"id": 1, "name": "Item 1"},
|
||||||
|
{"id": 2, "name": "Item 2"},
|
||||||
|
},
|
||||||
|
"count": 2,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to text
|
||||||
|
return "Sample text data", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(dataTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Image Generation Tool Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Tool returning image content
|
||||||
|
imageTool := mcp.Tool{
|
||||||
|
Name: "image.generate",
|
||||||
|
Description: "Generates a simple image",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"type": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Type of image to generate",
|
||||||
|
"default": "placeholder",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
// Return image content directly
|
||||||
|
return mcp.ImageContent{
|
||||||
|
Data: "base64EncodedImageData...", // Base64 encoded image data
|
||||||
|
MimeType: "image/png",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(imageTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Using ToolResult for Custom Outputs:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Tool that returns a custom ToolResult type
|
||||||
|
customResultTool := mcp.Tool{
|
||||||
|
Name: "custom.result",
|
||||||
|
Description: "Returns a custom formatted result",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"resultType": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"enum": []string{"text", "image"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
ResultType string `json:"resultType"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ResultType == "image" {
|
||||||
|
return mcp.ToolResult{
|
||||||
|
Type: mcp.ContentTypeImage,
|
||||||
|
Content: map[string]any{
|
||||||
|
"data": "base64EncodedImageData...",
|
||||||
|
"mimeType": "image/jpeg",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to text
|
||||||
|
return mcp.ToolResult{
|
||||||
|
Type: mcp.ContentTypeText,
|
||||||
|
Content: "This is a text result from ToolResult",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(customResultTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering Prompts
|
||||||
|
|
||||||
|
Prompts are reusable conversation templates for AI models.
|
||||||
|
|
||||||
|
#### Static Prompt Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a simple static prompt with placeholders
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "hello",
|
||||||
|
Description: "A simple hello prompt",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "name",
|
||||||
|
Description: "The name to greet",
|
||||||
|
Required: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Content: "Say hello to {{name}} and introduce yourself as an AI assistant.",
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Dynamic Prompt with Handler Function:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt with a dynamic handler function
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "dynamic-prompt",
|
||||||
|
Description: "A prompt that uses a handler to generate dynamic content",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "username",
|
||||||
|
Description: "User's name for personalized greeting",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "topic",
|
||||||
|
Description: "Topic of expertise",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a user message
|
||||||
|
userMessage := mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an assistant response with current time
|
||||||
|
currentTime := time.Now().Format(time.RFC1123)
|
||||||
|
assistantMessage := mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||||
|
req.Username, req.Topic, currentTime),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return both messages as a conversation
|
||||||
|
return []mcp.PromptMessage{userMessage, assistantMessage}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Multi-Message Prompt with Code Examples:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt that provides code examples in different programming languages
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "code-example",
|
||||||
|
Description: "Provides code examples in different programming languages",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "language",
|
||||||
|
Description: "Programming language for the example",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "complexity",
|
||||||
|
Description: "Complexity level (simple, medium, advanced)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Language string `json:"language"`
|
||||||
|
Complexity string `json:"complexity,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate language
|
||||||
|
supportedLanguages := map[string]bool{"go": true, "python": true, "javascript": true, "rust": true}
|
||||||
|
if !supportedLanguages[req.Language] {
|
||||||
|
return nil, fmt.Errorf("unsupported language: %s", req.Language)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate code example based on language and complexity
|
||||||
|
var codeExample string
|
||||||
|
|
||||||
|
switch req.Language {
|
||||||
|
case "go":
|
||||||
|
if req.Complexity == "simple" {
|
||||||
|
codeExample = `
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println("Hello, World!")
|
||||||
|
}`
|
||||||
|
} else {
|
||||||
|
codeExample = `
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
now := time.Now()
|
||||||
|
fmt.Printf("Hello, World! Current time is %s\n", now.Format(time.RFC3339))
|
||||||
|
}`
|
||||||
|
}
|
||||||
|
case "python":
|
||||||
|
// Python example code
|
||||||
|
if req.Complexity == "simple" {
|
||||||
|
codeExample = `
|
||||||
|
def greet(name):
|
||||||
|
return f"Hello, {name}!"
|
||||||
|
|
||||||
|
print(greet("World"))`
|
||||||
|
} else {
|
||||||
|
codeExample = `
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
def greet(name, include_time=False):
|
||||||
|
message = f"Hello, {name}!"
|
||||||
|
if include_time:
|
||||||
|
message += f" Current time is {datetime.datetime.now().isoformat()}"
|
||||||
|
return message
|
||||||
|
|
||||||
|
print(greet("World", include_time=True))`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create messages array according to MCP spec
|
||||||
|
messages := []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("You are a helpful coding assistant specialized in %s programming.", req.Language),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Show me a %s example of a Hello World program in %s.", req.Complexity, req.Language),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Here's a %s example in %s:\n\n```%s%s\n```\n\nHow can I help you implement this?",
|
||||||
|
req.Complexity, req.Language, req.Language, codeExample),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering Resources
|
||||||
|
|
||||||
|
Resources provide access to external content such as files or generated data.
|
||||||
|
|
||||||
|
#### Basic Resource Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a static resource
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "example-document",
|
||||||
|
URI: "file:///example/document.txt",
|
||||||
|
Description: "An example document",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///example/document.txt",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Text: "This is an example document content.",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Dynamic Resource with Code Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a Go code resource with dynamic handler
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "go-example",
|
||||||
|
URI: "file:///project/src/main.go",
|
||||||
|
Description: "A simple Go example with multiple files",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
// Return ResourceContent with all required fields
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///project/src/main.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a companion file for the above example
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "go-greeting",
|
||||||
|
URI: "file:///project/src/greeting/greeting.go",
|
||||||
|
Description: "A greeting package for the Go example",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///project/src/greeting/greeting.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Binary Resource Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a binary resource (like an image)
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "example-image",
|
||||||
|
URI: "file:///example/image.png",
|
||||||
|
Description: "An example image",
|
||||||
|
MimeType: "image/png",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
// Read image from file or generate it
|
||||||
|
imageData := "base64EncodedImageData..." // Base64 encoded image data
|
||||||
|
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///example/image.png",
|
||||||
|
MimeType: "image/png",
|
||||||
|
Blob: imageData, // For binary data
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using Resources in Prompts
|
||||||
|
|
||||||
|
You can embed resources in prompt responses to create rich interactions with proper MCP-compliant structure:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt that embeds a resource
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "resource-example",
|
||||||
|
Description: "A prompt that embeds a resource",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "file_type",
|
||||||
|
Description: "Type of file to show (rust or go)",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
FileType string `json:"file_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resourceURI, mimeType, fileContent string
|
||||||
|
if req.FileType == "rust" {
|
||||||
|
resourceURI = "file:///project/src/main.rs"
|
||||||
|
mimeType = "text/x-rust"
|
||||||
|
fileContent = "fn main() {\n println!(\"Hello world!\");\n}"
|
||||||
|
} else {
|
||||||
|
resourceURI = "file:///project/src/main.go"
|
||||||
|
mimeType = "text/x-go"
|
||||||
|
fileContent = "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello, world!\")\n}"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create message with embedded resource using proper MCP format
|
||||||
|
return []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Can you explain this %s code?", req.FileType),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.EmbeddedResource{
|
||||||
|
Type: mcp.ContentTypeResource,
|
||||||
|
Resource: struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Blob string `json:"blob,omitempty"`
|
||||||
|
}{
|
||||||
|
URI: resourceURI,
|
||||||
|
MimeType: mimeType,
|
||||||
|
Text: fileContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Above is a simple Hello World example in %s. Let me explain how it works.", req.FileType),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple File Resources Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt that demonstrates embedding multiple resource files
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "go-code-example",
|
||||||
|
Description: "A prompt that correctly embeds multiple resource files",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "format",
|
||||||
|
Description: "How to format the code display",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Format string `json:"format,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the Go code for multiple files
|
||||||
|
var mainGoText string = "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}"
|
||||||
|
var greetingGoText string = "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}"
|
||||||
|
|
||||||
|
// Create message with properly formatted embedded resource per MCP spec
|
||||||
|
messages := []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: "Show me a simple Go example with proper imports.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: "Here's a simple Go example project:",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.EmbeddedResource{
|
||||||
|
Type: mcp.ContentTypeResource,
|
||||||
|
Resource: struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Blob string `json:"blob,omitempty"`
|
||||||
|
}{
|
||||||
|
URI: "file:///project/src/main.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: mainGoText,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add explanation and additional file if requested
|
||||||
|
if req.Format == "with_explanation" {
|
||||||
|
messages = append(messages, mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: "This example demonstrates a simple Go application with modular structure. The main.go file imports from a local 'greeting' package that provides the Hello function.",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Also show the greeting.go file with correct resource format
|
||||||
|
messages = append(messages, mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.EmbeddedResource{
|
||||||
|
Type: mcp.ContentTypeResource,
|
||||||
|
Resource: struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Blob string `json:"blob,omitempty"`
|
||||||
|
}{
|
||||||
|
URI: "file:///project/src/greeting/greeting.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: greetingGoText,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complete Application Example
|
||||||
|
|
||||||
|
Here's a complete example demonstrating all the components:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load configuration
|
||||||
|
var c mcp.McpConf
|
||||||
|
if err := conf.Load("config.yaml", &c); err != nil {
|
||||||
|
log.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up logging
|
||||||
|
logx.DisableStat()
|
||||||
|
|
||||||
|
// Create MCP server
|
||||||
|
server := mcp.NewMcpServer(c)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
// Register a simple echo tool
|
||||||
|
echoTool := mcp.Tool{
|
||||||
|
Name: "echo",
|
||||||
|
Description: "Echoes back the message provided by the user",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The message to echo back",
|
||||||
|
},
|
||||||
|
"prefix": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional prefix to add to the echoed message",
|
||||||
|
"default": "Echo: ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"message"},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Prefix string `json:"prefix,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := "Echo: "
|
||||||
|
if len(req.Prefix) > 0 {
|
||||||
|
prefix = req.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefix + req.Message, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
server.RegisterTool(echoTool)
|
||||||
|
|
||||||
|
// Register a static prompt
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "greeting",
|
||||||
|
Description: "A simple greeting prompt",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "name",
|
||||||
|
Description: "The name to greet",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Content: "Hello {{name}}! How can I assist you today?",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a dynamic prompt
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "dynamic-prompt",
|
||||||
|
Description: "A prompt that uses a handler to generate dynamic content",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "username",
|
||||||
|
Description: "User's name for personalized greeting",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "topic",
|
||||||
|
Description: "Topic of expertise",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create messages with current time
|
||||||
|
currentTime := time.Now().Format(time.RFC1123)
|
||||||
|
return []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||||
|
req.Username, req.Topic, currentTime),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a resource
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "example-doc",
|
||||||
|
URI: "file:///example/doc.txt",
|
||||||
|
Description: "An example document",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///example/doc.txt",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Text: "This is the content of the example document.",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
fmt.Printf("Starting MCP server on %s:%d\n", c.Host, c.Port)
|
||||||
|
server.Start()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
The MCP implementation provides comprehensive error handling:
|
||||||
|
|
||||||
|
- Tool execution errors are properly reported back to clients
|
||||||
|
- Missing or invalid parameters are detected and reported with appropriate error codes
|
||||||
|
- Resource and prompt lookup failures are handled gracefully
|
||||||
|
- Timeout handling for long-running tool executions using context
|
||||||
|
- Panic recovery to prevent server crashes
|
||||||
|
|
||||||
|
## Advanced Features
|
||||||
|
|
||||||
|
- **Annotations**: Add audience and priority metadata to content
|
||||||
|
- **Content Types**: Support for text, images, audio, and other content formats
|
||||||
|
- **Embedded Resources**: Include file resources directly in prompt responses
|
||||||
|
- **Context Awareness**: All handlers receive context.Context for timeout and cancellation support
|
||||||
|
- **Progress Tokens**: Support for tracking progress of long-running operations
|
||||||
|
- **Customizable Timeouts**: Configure execution timeouts for tools and operations
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
- Tool execution runs with configurable timeouts to prevent blocking
|
||||||
|
- Efficient client tracking and cleanup to prevent resource leaks
|
||||||
|
- Proper concurrency handling with mutex protection for shared resources
|
||||||
|
- Buffered message channels to prevent blocking on client message delivery
|
||||||
|
|||||||
239
mcp/server.go
239
mcp/server.go
@@ -42,14 +42,14 @@ func NewMcpServer(c McpConf) McpServer {
|
|||||||
Method: http.MethodGet,
|
Method: http.MethodGet,
|
||||||
Path: s.conf.Mcp.SseEndpoint,
|
Path: s.conf.Mcp.SseEndpoint,
|
||||||
Handler: s.handleSSE,
|
Handler: s.handleSSE,
|
||||||
}, rest.WithSSE())
|
}, rest.WithSSE(), rest.WithTimeout(c.Mcp.SseTimeout))
|
||||||
|
|
||||||
// JSON-RPC message endpoint for regular requests
|
// JSON-RPC message endpoint for regular requests
|
||||||
s.server.AddRoute(rest.Route{
|
s.server.AddRoute(rest.Route{
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Path: s.conf.Mcp.MessageEndpoint,
|
Path: s.conf.Mcp.MessageEndpoint,
|
||||||
Handler: s.handleRequest,
|
Handler: s.handleRequest,
|
||||||
})
|
}, rest.WithTimeout(c.Mcp.MessageTimeout))
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
@@ -182,21 +182,23 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Always allow initialize and notifications/initialized regardless of client state
|
// Always allow initialize and notifications/initialized regardless of client state
|
||||||
if req.Method == methodInitialize {
|
if req.Method == methodInitialize {
|
||||||
logx.Infof("Processing initialize request with ID: %d", req.ID)
|
logx.Infof("Processing initialize request with ID: %d", req.ID)
|
||||||
s.processInitialize(client, req)
|
s.processInitialize(r.Context(), client, req)
|
||||||
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
|
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
|
||||||
return
|
return
|
||||||
} else if req.Method == methodNotificationsInitialized {
|
} else if req.Method == methodNotificationsInitialized {
|
||||||
// Handle initialized notification
|
// Handle initialized notification
|
||||||
logx.Info("Received notifications/initialized notification")
|
logx.Info("Received notifications/initialized notification")
|
||||||
if !isNotification {
|
if !isNotification {
|
||||||
s.sendErrorResponse(client, req.ID, "Method should be used as a notification", errCodeInvalidRequest)
|
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||||
|
"Method should be used as a notification", errCodeInvalidRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.processNotificationInitialized(client)
|
s.processNotificationInitialized(client)
|
||||||
return
|
return
|
||||||
} else if !client.initialized && req.Method != methodNotificationsCancelled {
|
} else if !client.initialized && req.Method != methodNotificationsCancelled {
|
||||||
// Block most requests until client is initialized (except for cancellations)
|
// Block most requests until client is initialized (except for cancellations)
|
||||||
s.sendErrorResponse(client, req.ID, "Client not fully initialized, waiting for notifications/initialized",
|
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||||
|
"Client not fully initialized, waiting for notifications/initialized",
|
||||||
errCodeClientNotInitialized)
|
errCodeClientNotInitialized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -205,41 +207,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
switch req.Method {
|
switch req.Method {
|
||||||
case methodToolsCall:
|
case methodToolsCall:
|
||||||
logx.Infof("Received tools call request with ID: %d", req.ID)
|
logx.Infof("Received tools call request with ID: %d", req.ID)
|
||||||
s.processToolCall(client, req)
|
s.processToolCall(r.Context(), client, req)
|
||||||
logx.Infof("Sent tools call response for ID: %d", req.ID)
|
logx.Infof("Sent tools call response for ID: %d", req.ID)
|
||||||
case methodToolsList:
|
case methodToolsList:
|
||||||
logx.Infof("Processing tools/list request with ID: %d", req.ID)
|
logx.Infof("Processing tools/list request with ID: %d", req.ID)
|
||||||
s.processListTools(client, req)
|
s.processListTools(r.Context(), client, req)
|
||||||
logx.Infof("Sent tools/list response for ID: %d", req.ID)
|
logx.Infof("Sent tools/list response for ID: %d", req.ID)
|
||||||
case methodPromptsList:
|
case methodPromptsList:
|
||||||
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
|
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
|
||||||
s.processListPrompts(client, req)
|
s.processListPrompts(r.Context(), client, req)
|
||||||
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
|
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
|
||||||
case methodPromptsGet:
|
case methodPromptsGet:
|
||||||
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
|
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
|
||||||
s.processGetPrompt(client, req)
|
s.processGetPrompt(r.Context(), client, req)
|
||||||
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
|
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
|
||||||
case methodResourcesList:
|
case methodResourcesList:
|
||||||
logx.Infof("Processing resources/list request with ID: %d", req.ID)
|
logx.Infof("Processing resources/list request with ID: %d", req.ID)
|
||||||
s.processListResources(client, req)
|
s.processListResources(r.Context(), client, req)
|
||||||
logx.Infof("Sent resources/list response for ID: %d", req.ID)
|
logx.Infof("Sent resources/list response for ID: %d", req.ID)
|
||||||
case methodResourcesRead:
|
case methodResourcesRead:
|
||||||
logx.Infof("Processing resources/read request with ID: %d", req.ID)
|
logx.Infof("Processing resources/read request with ID: %d", req.ID)
|
||||||
s.processResourcesRead(client, req)
|
s.processResourcesRead(r.Context(), client, req)
|
||||||
logx.Infof("Sent resources/read response for ID: %d", req.ID)
|
logx.Infof("Sent resources/read response for ID: %d", req.ID)
|
||||||
case methodResourcesSubscribe:
|
case methodResourcesSubscribe:
|
||||||
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
|
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
|
||||||
s.processResourceSubscribe(client, req)
|
s.processResourceSubscribe(r.Context(), client, req)
|
||||||
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
|
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
|
||||||
case methodPing:
|
case methodPing:
|
||||||
logx.Infof("Processing ping request with ID: %d", req.ID)
|
logx.Infof("Processing ping request with ID: %d", req.ID)
|
||||||
s.processPing(client, req)
|
s.processPing(r.Context(), client, req)
|
||||||
case methodNotificationsCancelled:
|
case methodNotificationsCancelled:
|
||||||
logx.Infof("Received notifications/cancelled notification: %v", req.Params)
|
logx.Infof("Received notifications/cancelled notification: %d", req.ID)
|
||||||
s.processNotificationCancelled(client, req)
|
s.processNotificationCancelled(r.Context(), client, req)
|
||||||
default:
|
default:
|
||||||
logx.Infof("Unknown method: %s", req.Method)
|
logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID)
|
||||||
s.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -321,7 +323,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
case <-r.Context().Done():
|
case <-r.Context().Done():
|
||||||
// Client disconnected or request was canceled
|
// Client disconnected or request was canceled or timed out
|
||||||
logx.Infof("Client %s disconnected: context done", sessionId)
|
logx.Infof("Client %s disconnected: context done", sessionId)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -329,7 +331,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// processInitialize processes the initialize request
|
// processInitialize processes the initialize request
|
||||||
func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processInitialize(ctx context.Context, client *mcpClient, req Request) {
|
||||||
// Create a proper JSON-RPC response that preserves the client's request ID
|
// Create a proper JSON-RPC response that preserves the client's request ID
|
||||||
result := initializationResponse{
|
result := initializationResponse{
|
||||||
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
|
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
|
||||||
@@ -362,11 +364,11 @@ func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
|
|||||||
client.initialized = true
|
client.initialized = true
|
||||||
|
|
||||||
// Send response with client's original request ID
|
// Send response with client's original request ID
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processListTools processes the tools/list request
|
// processListTools processes the tools/list request
|
||||||
func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processListTools(ctx context.Context, client *mcpClient, req Request) {
|
||||||
// Extract pagination params if any
|
// Extract pagination params if any
|
||||||
var nextCursor string
|
var nextCursor string
|
||||||
var progressToken any
|
var progressToken any
|
||||||
@@ -390,6 +392,9 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
|||||||
var toolsList []Tool
|
var toolsList []Tool
|
||||||
s.toolsLock.Lock()
|
s.toolsLock.Lock()
|
||||||
for _, tool := range s.tools {
|
for _, tool := range s.tools {
|
||||||
|
if len(tool.InputSchema.Type) == 0 {
|
||||||
|
tool.InputSchema.Type = ContentTypeObject
|
||||||
|
}
|
||||||
toolsList = append(toolsList, tool)
|
toolsList = append(toolsList, tool)
|
||||||
}
|
}
|
||||||
s.toolsLock.Unlock()
|
s.toolsLock.Unlock()
|
||||||
@@ -405,15 +410,15 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
|||||||
// Add meta information if progress token was provided
|
// Add meta information if progress token was provided
|
||||||
if progressToken != nil {
|
if progressToken != nil {
|
||||||
result.Result.Meta = map[string]any{
|
result.Result.Meta = map[string]any{
|
||||||
"progressToken": progressToken,
|
progressTokenKey: progressToken,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processListPrompts processes the prompts/list request
|
// processListPrompts processes the prompts/list request
|
||||||
func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processListPrompts(ctx context.Context, client *mcpClient, req Request) {
|
||||||
// Extract pagination params if any
|
// Extract pagination params if any
|
||||||
var nextCursor string
|
var nextCursor string
|
||||||
if req.Params != nil {
|
if req.Params != nil {
|
||||||
@@ -447,11 +452,11 @@ func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
|
|||||||
NextCursor: nextCursor,
|
NextCursor: nextCursor,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processListResources processes the resources/list request
|
// processListResources processes the resources/list request
|
||||||
func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processListResources(ctx context.Context, client *mcpClient, req Request) {
|
||||||
// Extract pagination params if any
|
// Extract pagination params if any
|
||||||
var nextCursor string
|
var nextCursor string
|
||||||
var progressToken any
|
var progressToken any
|
||||||
@@ -493,15 +498,15 @@ func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
|
|||||||
// Add meta information if progress token was provided
|
// Add meta information if progress token was provided
|
||||||
if progressToken != nil {
|
if progressToken != nil {
|
||||||
result.Result.Meta = map[string]any{
|
result.Result.Meta = map[string]any{
|
||||||
"progressToken": progressToken,
|
progressTokenKey: progressToken,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processGetPrompt processes the prompts/get request
|
// processGetPrompt processes the prompts/get request
|
||||||
func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processGetPrompt(ctx context.Context, client *mcpClient, req Request) {
|
||||||
type GetPromptParams struct {
|
type GetPromptParams struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Arguments map[string]string `json:"arguments,omitempty"`
|
Arguments map[string]string `json:"arguments,omitempty"`
|
||||||
@@ -509,7 +514,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
|||||||
|
|
||||||
var params GetPromptParams
|
var params GetPromptParams
|
||||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -519,7 +524,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
|||||||
s.promptsLock.Unlock()
|
s.promptsLock.Unlock()
|
||||||
if !exists {
|
if !exists {
|
||||||
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
|
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
|
||||||
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -529,12 +534,15 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
|||||||
missingArgs := validatePromptArguments(prompt, params.Arguments)
|
missingArgs := validatePromptArguments(prompt, params.Arguments)
|
||||||
if len(missingArgs) > 0 {
|
if len(missingArgs) > 0 {
|
||||||
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
|
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
|
||||||
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply default values for missing optional arguments
|
// Ensure arguments are initialized to an empty map if nil
|
||||||
args := applyDefaultArguments(prompt, params.Arguments)
|
if params.Arguments == nil {
|
||||||
|
params.Arguments = make(map[string]string)
|
||||||
|
}
|
||||||
|
args := params.Arguments
|
||||||
|
|
||||||
// Generate messages using handler or static content
|
// Generate messages using handler or static content
|
||||||
var messages []PromptMessage
|
var messages []PromptMessage
|
||||||
@@ -542,17 +550,17 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
|||||||
|
|
||||||
if prompt.Handler != nil {
|
if prompt.Handler != nil {
|
||||||
// Use dynamic handler to generate messages
|
// Use dynamic handler to generate messages
|
||||||
logx.Info("Using prompt handler to generate content")
|
messages, err = prompt.Handler(ctx, args)
|
||||||
messages, err = prompt.Handler(args)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logx.Errorf("Error from prompt handler: %v", err)
|
logx.Errorf("Error from prompt handler: %v", err)
|
||||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
s.sendErrorResponse(ctx, client, req.ID,
|
||||||
|
fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// No handler, generate messages from static content
|
// No handler, generate messages from static content
|
||||||
var messageText string
|
var messageText string
|
||||||
if prompt.Content != "" {
|
if len(prompt.Content) > 0 {
|
||||||
messageText = prompt.Content
|
messageText = prompt.Content
|
||||||
|
|
||||||
// Apply argument substitutions to static content
|
// Apply argument substitutions to static content
|
||||||
@@ -560,21 +568,13 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
|||||||
placeholder := fmt.Sprintf("{{%s}}", key)
|
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||||
messageText = strings.Replace(messageText, placeholder, value, -1)
|
messageText = strings.Replace(messageText, placeholder, value, -1)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// No content, use a default fallback
|
|
||||||
topic := "this topic"
|
|
||||||
if t, ok := args["topic"]; ok && t != "" {
|
|
||||||
topic = t
|
|
||||||
}
|
|
||||||
messageText = fmt.Sprintf("Tell me about %s", topic)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a single user message with the content
|
// Create a single user message with the content
|
||||||
messages = []PromptMessage{
|
messages = []PromptMessage{
|
||||||
{
|
{
|
||||||
Role: roleUser,
|
Role: RoleUser,
|
||||||
Content: TextContent{
|
Content: TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: messageText,
|
Text: messageText,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -587,49 +587,14 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
|||||||
Messages []PromptMessage `json:"messages"`
|
Messages []PromptMessage `json:"messages"`
|
||||||
}{
|
}{
|
||||||
Description: prompt.Description,
|
Description: prompt.Description,
|
||||||
Messages: messages,
|
Messages: toTypedPromptMessages(messages),
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
}
|
|
||||||
|
|
||||||
// validatePromptArguments checks if all required arguments are provided
|
|
||||||
// Returns a list of missing required arguments
|
|
||||||
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
|
||||||
var missingArgs []string
|
|
||||||
|
|
||||||
for _, arg := range prompt.Arguments {
|
|
||||||
if arg.Required {
|
|
||||||
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
|
||||||
missingArgs = append(missingArgs, arg.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return missingArgs
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyDefaultArguments adds default values for missing optional arguments
|
|
||||||
func applyDefaultArguments(prompt Prompt, providedArgs map[string]string) map[string]string {
|
|
||||||
result := make(map[string]string)
|
|
||||||
|
|
||||||
// Copy all provided arguments
|
|
||||||
for k, v := range providedArgs {
|
|
||||||
result[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add defaults for missing arguments
|
|
||||||
for _, arg := range prompt.Arguments {
|
|
||||||
if _, exists := result[arg.Name]; !exists && arg.Default != "" {
|
|
||||||
result[arg.Name] = arg.Default
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// processToolCall processes the tools/call request
|
// processToolCall processes the tools/call request
|
||||||
func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processToolCall(ctx context.Context, client *mcpClient, req Request) {
|
||||||
var toolCallParams struct {
|
var toolCallParams struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Arguments map[string]any `json:"arguments,omitempty"`
|
Arguments map[string]any `json:"arguments,omitempty"`
|
||||||
@@ -642,7 +607,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
// If it's a RawMessage (JSON), unmarshal it
|
// If it's a RawMessage (JSON), unmarshal it
|
||||||
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
|
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
|
||||||
logx.Errorf("Failed to unmarshal tool call params: %v", err)
|
logx.Errorf("Failed to unmarshal tool call params: %v", err)
|
||||||
s.sendErrorResponse(client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -654,15 +619,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
tool, exists := s.tools[toolCallParams.Name]
|
tool, exists := s.tools[toolCallParams.Name]
|
||||||
s.toolsLock.Unlock()
|
s.toolsLock.Unlock()
|
||||||
if !exists {
|
if !exists {
|
||||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||||
toolCallParams.Name), errCodeInvalidParams)
|
toolCallParams.Name), errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a context with the configured timeout
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), s.conf.Mcp.ToolTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Log parameters before execution
|
// Log parameters before execution
|
||||||
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
|
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
|
||||||
|
|
||||||
@@ -671,6 +632,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Create a channel to receive the result
|
// Create a channel to receive the result
|
||||||
|
// make sure to have 1 size buffer to avoid channel leak if timeout
|
||||||
resultCh := make(chan struct {
|
resultCh := make(chan struct {
|
||||||
result any
|
result any
|
||||||
err error
|
err error
|
||||||
@@ -678,7 +640,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
|
|
||||||
// Execute the tool handler in a goroutine
|
// Execute the tool handler in a goroutine
|
||||||
go func() {
|
go func() {
|
||||||
toolResult, toolErr := tool.Handler(toolCallParams.Arguments)
|
toolResult, toolErr := tool.Handler(ctx, toolCallParams.Arguments)
|
||||||
resultCh <- struct {
|
resultCh <- struct {
|
||||||
result any
|
result any
|
||||||
err error
|
err error
|
||||||
@@ -694,9 +656,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
result = res.result
|
result = res.result
|
||||||
err = res.err
|
err = res.err
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Handle timeout
|
// Handle request timeout
|
||||||
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.ToolTimeout, toolCallParams.Name)
|
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.MessageTimeout, toolCallParams.Name)
|
||||||
s.sendErrorResponse(client, req.ID, "Tool execution timed out", errCodeTimeout)
|
s.sendErrorResponse(ctx, client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -710,7 +672,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
// Add meta information if progress token was provided
|
// Add meta information if progress token was provided
|
||||||
if progressToken != nil {
|
if progressToken != nil {
|
||||||
callToolResult.Result.Meta = map[string]any{
|
callToolResult.Result.Meta = map[string]any{
|
||||||
"progressToken": progressToken,
|
progressTokenKey: progressToken,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -722,12 +684,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
|
|
||||||
callToolResult.Content = []any{
|
callToolResult.Content = []any{
|
||||||
TextContent{
|
TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: fmt.Sprintf("Error: %v", err),
|
Text: fmt.Sprintf("Error: %v", err),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
callToolResult.IsError = true
|
callToolResult.IsError = true
|
||||||
s.sendResponse(client, req.ID, callToolResult)
|
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -736,10 +697,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
case string:
|
case string:
|
||||||
// Simple string becomes text content
|
// Simple string becomes text content
|
||||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: v,
|
Text: v,
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
@@ -749,69 +709,63 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
|||||||
jsonStr = []byte(err.Error())
|
jsonStr = []byte(err.Error())
|
||||||
}
|
}
|
||||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: string(jsonStr),
|
Text: string(jsonStr),
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
case TextContent:
|
case TextContent:
|
||||||
// Direct TextContent object
|
|
||||||
callToolResult.Content = append(callToolResult.Content, v)
|
callToolResult.Content = append(callToolResult.Content, v)
|
||||||
case ImageContent:
|
case ImageContent:
|
||||||
// Direct ImageContent object
|
|
||||||
callToolResult.Content = append(callToolResult.Content, v)
|
callToolResult.Content = append(callToolResult.Content, v)
|
||||||
case []any:
|
case []any:
|
||||||
// Array of content items
|
|
||||||
callToolResult.Content = v
|
callToolResult.Content = v
|
||||||
case ToolResult:
|
case ToolResult:
|
||||||
// Handle legacy ToolResult type
|
// Handle legacy ToolResult type
|
||||||
switch v.Type {
|
switch v.Type {
|
||||||
case contentTypeText:
|
case ContentTypeText:
|
||||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: fmt.Sprintf("%v", v.Content),
|
Text: fmt.Sprintf("%v", v.Content),
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
case contentTypeImage:
|
case ContentTypeImage:
|
||||||
if imgData, ok := v.Content.(map[string]any); ok {
|
if imgData, ok := v.Content.(map[string]any); ok {
|
||||||
callToolResult.Content = append(callToolResult.Content, ImageContent{
|
callToolResult.Content = append(callToolResult.Content, ImageContent{
|
||||||
Type: contentTypeImage,
|
|
||||||
Data: fmt.Sprintf("%v", imgData["data"]),
|
Data: fmt.Sprintf("%v", imgData["data"]),
|
||||||
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
|
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: fmt.Sprintf("%v", v.Content),
|
Text: fmt.Sprintf("%v", v.Content),
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// For any other type, convert to string
|
// For any other type, convert to string
|
||||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
Type: contentTypeText,
|
|
||||||
Text: fmt.Sprintf("%v", v),
|
Text: fmt.Sprintf("%v", v),
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
callToolResult.Content = toTypedContents(callToolResult.Content)
|
||||||
logx.Infof("Tool call result: %#v", callToolResult)
|
logx.Infof("Tool call result: %#v", callToolResult)
|
||||||
s.sendResponse(client, req.ID, callToolResult)
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processResourcesRead processes the resources/read request
|
// processResourcesRead processes the resources/read request
|
||||||
func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClient, req Request) {
|
||||||
var params ResourceReadParams
|
var params ResourceReadParams
|
||||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -821,7 +775,7 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
|||||||
s.resourcesLock.Unlock()
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||||
params.URI), errCodeResourceNotFound)
|
params.URI), errCodeResourceNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -837,14 +791,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute the resource handler
|
// Execute the resource handler
|
||||||
content, err := resource.Handler()
|
content, err := resource.Handler(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||||
errCodeInternalError)
|
errCodeInternalError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -865,14 +819,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
|||||||
Contents: []ResourceContent{content},
|
Contents: []ResourceContent{content},
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sendResponse(client, req.ID, result)
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processResourceSubscribe processes the resources/subscribe request
|
// processResourceSubscribe processes the resources/subscribe request
|
||||||
func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processResourceSubscribe(ctx context.Context, client *mcpClient, req Request) {
|
||||||
var params ResourceSubscribeParams
|
var params ResourceSubscribeParams
|
||||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -882,19 +836,17 @@ func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request)
|
|||||||
s.resourcesLock.Unlock()
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||||
params.URI), errCodeResourceNotFound)
|
params.URI), errCodeResourceNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send success response for the subscription
|
// Send success response for the subscription
|
||||||
s.sendResponse(client, req.ID, struct{}{})
|
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||||
|
|
||||||
logx.Infof("Client %s subscribed to resource '%s'", client.id, params.URI)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// processNotificationCancelled processes the notifications/cancelled notification
|
// processNotificationCancelled processes the notifications/cancelled notification
|
||||||
func (s *sseMcpServer) processNotificationCancelled(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processNotificationCancelled(ctx context.Context, client *mcpClient, req Request) {
|
||||||
// Extract the requestId that was canceled
|
// Extract the requestId that was canceled
|
||||||
type CancelParams struct {
|
type CancelParams struct {
|
||||||
RequestId int64 `json:"requestId"`
|
RequestId int64 `json:"requestId"`
|
||||||
@@ -918,18 +870,17 @@ func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// processPing processes the ping request and responds immediately
|
// processPing processes the ping request and responds immediately
|
||||||
func (s *sseMcpServer) processPing(client *mcpClient, req Request) {
|
func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req Request) {
|
||||||
// A ping request should simply respond with an empty result to confirm the server is alive
|
// A ping request should simply respond with an empty result to confirm the server is alive
|
||||||
logx.Infof("Received ping request with ID: %d", req.ID)
|
logx.Infof("Received ping request with ID: %d", req.ID)
|
||||||
|
|
||||||
// Send an empty response with client's original request ID
|
// Send an empty response with client's original request ID
|
||||||
s.sendResponse(client, req.ID, struct{}{})
|
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||||
|
|
||||||
logx.Infof("Sent ping response for ID: %d", req.ID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendErrorResponse sends an error response via the SSE channel
|
// sendErrorResponse sends an error response via the SSE channel
|
||||||
func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message string, code int) {
|
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||||
|
id int64, message string, code int) {
|
||||||
errorResponse := struct {
|
errorResponse := struct {
|
||||||
JsonRpc string `json:"jsonrpc"`
|
JsonRpc string `json:"jsonrpc"`
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
@@ -949,11 +900,17 @@ func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message st
|
|||||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||||
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
|
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
|
||||||
|
|
||||||
client.channel <- sseMessage
|
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||||
|
select {
|
||||||
|
case client.channel <- sseMessage:
|
||||||
|
default:
|
||||||
|
// Channel buffer is full, log warning and continue
|
||||||
|
logx.Infof("Client %s channel is full while sending error response with ID %d", client.id, id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendResponse sends a success response via the SSE channel
|
// sendResponse sends a success response via the SSE channel
|
||||||
func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) {
|
||||||
response := Response{
|
response := Response{
|
||||||
JsonRpc: jsonRpcVersion,
|
JsonRpc: jsonRpcVersion,
|
||||||
ID: id,
|
ID: id,
|
||||||
@@ -962,7 +919,7 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
|||||||
|
|
||||||
jsonData, err := json.Marshal(response)
|
jsonData, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.sendErrorResponse(client, id, "Failed to marshal response", errCodeInternalError)
|
s.sendErrorResponse(ctx, client, id, "Failed to marshal response", errCodeInternalError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -970,5 +927,11 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
|||||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||||
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
|
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
|
||||||
|
|
||||||
client.channel <- sseMessage
|
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||||
|
select {
|
||||||
|
case client.channel <- sseMessage:
|
||||||
|
default:
|
||||||
|
// Channel buffer is full, log warning and continue
|
||||||
|
logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ host: localhost
|
|||||||
port: 8080
|
port: 8080
|
||||||
mcp:
|
mcp:
|
||||||
name: mcp-test-server
|
name: mcp-test-server
|
||||||
toolTimeout: 5s
|
messageTimeout: 5s
|
||||||
`
|
`
|
||||||
|
|
||||||
var c McpConf
|
var c McpConf
|
||||||
@@ -82,7 +82,6 @@ func (m *mockMcpServer) registerExampleTool() {
|
|||||||
Name: "test.tool",
|
Name: "test.tool",
|
||||||
Description: "A test tool",
|
Description: "A test tool",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{
|
Properties: map[string]any{
|
||||||
"input": map[string]any{
|
"input": map[string]any{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -91,7 +90,7 @@ func (m *mockMcpServer) registerExampleTool() {
|
|||||||
},
|
},
|
||||||
Required: []string{"input"},
|
Required: []string{"input"},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
input, ok := params["input"].(string)
|
input, ok := params["input"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid input parameter")
|
return nil, fmt.Errorf("invalid input parameter")
|
||||||
@@ -135,7 +134,7 @@ port: 8080
|
|||||||
mcp:
|
mcp:
|
||||||
cors:
|
cors:
|
||||||
- http://localhost:3000
|
- http://localhost:3000
|
||||||
toolTimeout: 5s
|
messageTimeout: 5s
|
||||||
`
|
`
|
||||||
|
|
||||||
var c McpConf
|
var c McpConf
|
||||||
@@ -186,7 +185,6 @@ func TestRegisterTool(t *testing.T) {
|
|||||||
Name: "example.tool",
|
Name: "example.tool",
|
||||||
Description: "An example tool",
|
Description: "An example tool",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{
|
Properties: map[string]any{
|
||||||
"input": map[string]any{
|
"input": map[string]any{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -194,7 +192,7 @@ func TestRegisterTool(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return "result", nil
|
return "result", nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -280,7 +278,7 @@ func TestToolsList(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processListTools(client, req)
|
mock.server.processListTools(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -328,7 +326,7 @@ func TestToolCallBasic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -355,8 +353,7 @@ func TestToolCallBasic(t *testing.T) {
|
|||||||
|
|
||||||
// Verify the response content
|
// Verify the response content
|
||||||
assert.Len(t, parsed.Result.Content, 1, "Response should contain one content item")
|
assert.Len(t, parsed.Result.Content, 1, "Response should contain one content item")
|
||||||
assert.Equal(t, "text", parsed.Result.Content[0]["type"], "Content type should be text")
|
assert.Equal(t, "Processed: test-input", parsed.Result.Content[0][ContentTypeText], "Tool result incorrect")
|
||||||
assert.Equal(t, "Processed: test-input", parsed.Result.Content[0]["text"], "Tool result incorrect")
|
|
||||||
assert.False(t, parsed.Result.IsError, "Response should not be an error")
|
assert.False(t, parsed.Result.IsError, "Response should not be an error")
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for tool call response")
|
t.Fatal("Timed out waiting for tool call response")
|
||||||
@@ -373,10 +370,9 @@ func TestToolCallMapResult(t *testing.T) {
|
|||||||
Name: "map.tool",
|
Name: "map.tool",
|
||||||
Description: "A tool that returns a map result",
|
Description: "A tool that returns a map result",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
// Return a complex nested map structure
|
// Return a complex nested map structure
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"string": "value",
|
"string": "value",
|
||||||
@@ -417,7 +413,7 @@ func TestToolCallMapResult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -445,13 +441,8 @@ func TestToolCallMapResult(t *testing.T) {
|
|||||||
firstItem, ok := content[0].(map[string]any)
|
firstItem, ok := content[0].(map[string]any)
|
||||||
require.True(t, ok, "First content item should be an object")
|
require.True(t, ok, "First content item should be an object")
|
||||||
|
|
||||||
// Verify it's a text content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
|
||||||
|
|
||||||
// Get the text content which should be our JSON
|
// Get the text content which should be our JSON
|
||||||
text, ok := firstItem["text"].(string)
|
text, ok := firstItem[ContentTypeText].(string)
|
||||||
require.True(t, ok, "Content should have text")
|
require.True(t, ok, "Content should have text")
|
||||||
|
|
||||||
// Verify the text is valid JSON and contains our data
|
// Verify the text is valid JSON and contains our data
|
||||||
@@ -496,10 +487,9 @@ func TestToolCallArrayResult(t *testing.T) {
|
|||||||
Name: "array.tool",
|
Name: "array.tool",
|
||||||
Description: "A tool that returns an array result",
|
Description: "A tool that returns an array result",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
// Return an array of mixed content types
|
// Return an array of mixed content types
|
||||||
return []any{
|
return []any{
|
||||||
"string item",
|
"string item",
|
||||||
@@ -536,7 +526,7 @@ func TestToolCallArrayResult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -574,16 +564,14 @@ func TestToolCallTextContentResult(t *testing.T) {
|
|||||||
Name: "text.content.tool",
|
Name: "text.content.tool",
|
||||||
Description: "A tool that returns a TextContent result",
|
Description: "A tool that returns a TextContent result",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
// Return a TextContent object directly
|
// Return a TextContent object directly
|
||||||
return TextContent{
|
return TextContent{
|
||||||
Type: "text",
|
|
||||||
Text: "This is a direct TextContent result",
|
Text: "This is a direct TextContent result",
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
Priority: func() *float64 { p := 0.9; return &p }(),
|
Priority: func() *float64 { p := 0.9; return &p }(),
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
@@ -614,7 +602,7 @@ func TestToolCallTextContentResult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -642,16 +630,6 @@ func TestToolCallTextContentResult(t *testing.T) {
|
|||||||
firstItem, ok := content[0].(map[string]any)
|
firstItem, ok := content[0].(map[string]any)
|
||||||
require.True(t, ok, "First content item should be an object")
|
require.True(t, ok, "First content item should be an object")
|
||||||
|
|
||||||
// Verify it's a text content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
|
||||||
|
|
||||||
// Check text content
|
|
||||||
text, ok := firstItem["text"].(string)
|
|
||||||
require.True(t, ok, "Content should have text")
|
|
||||||
assert.Equal(t, "This is a direct TextContent result", text, "Text content should match")
|
|
||||||
|
|
||||||
// Check annotations
|
// Check annotations
|
||||||
annotations, ok := firstItem["annotations"].(map[string]any)
|
annotations, ok := firstItem["annotations"].(map[string]any)
|
||||||
require.True(t, ok, "Should have annotations")
|
require.True(t, ok, "Should have annotations")
|
||||||
@@ -679,13 +657,11 @@ func TestToolCallImageContentResult(t *testing.T) {
|
|||||||
Name: "image.content.tool",
|
Name: "image.content.tool",
|
||||||
Description: "A tool that returns an ImageContent result",
|
Description: "A tool that returns an ImageContent result",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
// Return an ImageContent object directly
|
// Return an ImageContent object directly
|
||||||
return ImageContent{
|
return ImageContent{
|
||||||
Type: "image",
|
|
||||||
Data: "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", // "test image data (base64 encoded)" in base64
|
Data: "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", // "test image data (base64 encoded)" in base64
|
||||||
MimeType: "image/png",
|
MimeType: "image/png",
|
||||||
}, nil
|
}, nil
|
||||||
@@ -716,7 +692,7 @@ func TestToolCallImageContentResult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -744,11 +720,6 @@ func TestToolCallImageContentResult(t *testing.T) {
|
|||||||
firstItem, ok := content[0].(map[string]any)
|
firstItem, ok := content[0].(map[string]any)
|
||||||
require.True(t, ok, "First content item should be an object")
|
require.True(t, ok, "First content item should be an object")
|
||||||
|
|
||||||
// Verify it's an image content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "image", contentType, "Content type should be image")
|
|
||||||
|
|
||||||
// Check image data
|
// Check image data
|
||||||
data, ok := firstItem["data"].(string)
|
data, ok := firstItem["data"].(string)
|
||||||
require.True(t, ok, "Content should have data")
|
require.True(t, ok, "Content should have data")
|
||||||
@@ -773,12 +744,12 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
Name: "toolresult.tool",
|
Name: "toolresult.tool",
|
||||||
Description: "A tool that returns a ToolResult object",
|
Description: "A tool that returns a ToolResult object",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
Type: ContentTypeObject,
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return ToolResult{
|
return ToolResult{
|
||||||
Type: "text",
|
Type: ContentTypeText,
|
||||||
Content: "This is a ToolResult with text content type",
|
Content: "This is a ToolResult with text content type",
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
@@ -790,10 +761,10 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
Name: "toolresult.image.tool",
|
Name: "toolresult.image.tool",
|
||||||
Description: "A tool that returns a ToolResult with image content",
|
Description: "A tool that returns a ToolResult with image content",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
Type: ContentTypeObject,
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return ToolResult{
|
return ToolResult{
|
||||||
Type: "image",
|
Type: "image",
|
||||||
Content: map[string]any{
|
Content: map[string]any{
|
||||||
@@ -810,10 +781,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
Name: "toolresult.audio.tool",
|
Name: "toolresult.audio.tool",
|
||||||
Description: "A tool that returns a ToolResult with audio content",
|
Description: "A tool that returns a ToolResult with audio content",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
// Test with image type
|
// Test with image type
|
||||||
return ToolResult{
|
return ToolResult{
|
||||||
Type: "audio",
|
Type: "audio",
|
||||||
@@ -831,10 +801,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
Name: "toolresult.int.tool",
|
Name: "toolresult.int.tool",
|
||||||
Description: "A tool that returns a ToolResult with int content",
|
Description: "A tool that returns a ToolResult with int content",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return 2, nil
|
return 2, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -845,10 +814,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
Name: "toolresult.bad.tool",
|
Name: "toolresult.bad.tool",
|
||||||
Description: "A tool that returns a ToolResult with bad content",
|
Description: "A tool that returns a ToolResult with bad content",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"type": "custom",
|
"type": "custom",
|
||||||
"data": make(chan int),
|
"data": make(chan int),
|
||||||
@@ -881,7 +849,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -909,13 +877,8 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
firstItem, ok := content[0].(map[string]any)
|
firstItem, ok := content[0].(map[string]any)
|
||||||
require.True(t, ok, "First content item should be an object")
|
require.True(t, ok, "First content item should be an object")
|
||||||
|
|
||||||
// Verify it's a text content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
|
||||||
|
|
||||||
// Check text content
|
// Check text content
|
||||||
text, ok := firstItem["text"].(string)
|
text, ok := firstItem[ContentTypeText].(string)
|
||||||
require.True(t, ok, "Content should have text")
|
require.True(t, ok, "Content should have text")
|
||||||
assert.Equal(t, "This is a ToolResult with text content type", text, "Text content should match")
|
assert.Equal(t, "This is a ToolResult with text content type", text, "Text content should match")
|
||||||
|
|
||||||
@@ -947,7 +910,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -975,11 +938,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
firstItem, ok := content[0].(map[string]any)
|
firstItem, ok := content[0].(map[string]any)
|
||||||
require.True(t, ok, "First content item should be an object")
|
require.True(t, ok, "First content item should be an object")
|
||||||
|
|
||||||
// Verify it's an image content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "image", contentType, "Content type should be image")
|
|
||||||
|
|
||||||
// Check image data and mime type
|
// Check image data and mime type
|
||||||
data, ok := firstItem["data"].(string)
|
data, ok := firstItem["data"].(string)
|
||||||
require.True(t, ok, "Content should have data")
|
require.True(t, ok, "Content should have data")
|
||||||
@@ -1017,7 +975,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -1040,15 +998,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
content, ok := result["content"].([]any)
|
content, ok := result["content"].([]any)
|
||||||
require.True(t, ok, "Result should have a content array")
|
require.True(t, ok, "Result should have a content array")
|
||||||
require.NotEmpty(t, content, "Content should not be empty")
|
require.NotEmpty(t, content, "Content should not be empty")
|
||||||
|
|
||||||
// The first content item should be converted from ToolResult to ImageContent
|
|
||||||
firstItem, ok := content[0].(map[string]any)
|
|
||||||
require.True(t, ok, "First content item should be an object")
|
|
||||||
|
|
||||||
// Verify it's an image content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "text", contentType, "Content type should be image")
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for tool call response")
|
t.Fatal("Timed out waiting for tool call response")
|
||||||
}
|
}
|
||||||
@@ -1077,7 +1026,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -1100,15 +1049,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
content, ok := result["content"].([]any)
|
content, ok := result["content"].([]any)
|
||||||
require.True(t, ok, "Result should have a content array")
|
require.True(t, ok, "Result should have a content array")
|
||||||
require.NotEmpty(t, content, "Content should not be empty")
|
require.NotEmpty(t, content, "Content should not be empty")
|
||||||
|
|
||||||
// The first content item should be converted from ToolResult to ImageContent
|
|
||||||
firstItem, ok := content[0].(map[string]any)
|
|
||||||
require.True(t, ok, "First content item should be an object")
|
|
||||||
|
|
||||||
// Verify it's an image content
|
|
||||||
contentType, ok := firstItem["type"].(string)
|
|
||||||
require.True(t, ok, "Content should have a type")
|
|
||||||
assert.Equal(t, "text", contentType, "Content type should be image")
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for tool call response")
|
t.Fatal("Timed out waiting for tool call response")
|
||||||
}
|
}
|
||||||
@@ -1137,7 +1077,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Get the response from the client's channel
|
// Get the response from the client's channel
|
||||||
select {
|
select {
|
||||||
@@ -1159,10 +1099,9 @@ func TestToolCallError(t *testing.T) {
|
|||||||
Name: "error.tool",
|
Name: "error.tool",
|
||||||
Description: "A tool that returns an error",
|
Description: "A tool that returns an error",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return nil, fmt.Errorf("tool execution failed")
|
return nil, fmt.Errorf("tool execution failed")
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -1189,7 +1128,7 @@ func TestToolCallError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call
|
// Process the tool call
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Check the response
|
// Check the response
|
||||||
select {
|
select {
|
||||||
@@ -1207,20 +1146,16 @@ func TestToolCallTimeout(t *testing.T) {
|
|||||||
mock := newMockMcpServer(t)
|
mock := newMockMcpServer(t)
|
||||||
defer mock.shutdown()
|
defer mock.shutdown()
|
||||||
|
|
||||||
// Set a very short timeout for testing
|
|
||||||
mock.server.conf.Mcp.ToolTimeout = 10 * time.Millisecond
|
|
||||||
|
|
||||||
// Register a tool that times out
|
// Register a tool that times out
|
||||||
err := mock.server.RegisterTool(Tool{
|
err := mock.server.RegisterTool(Tool{
|
||||||
Name: "timeout.tool",
|
Name: "timeout.tool",
|
||||||
Description: "A tool that times out",
|
Description: "A tool that times out",
|
||||||
InputSchema: InputSchema{
|
InputSchema: InputSchema{
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]any{},
|
Properties: map[string]any{},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
time.Sleep(50 * time.Millisecond) // Sleep longer than timeout
|
<-ctx.Done()
|
||||||
return "this should never be returned", nil
|
return nil, fmt.Errorf("tool execution timed out")
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1244,16 +1179,24 @@ func TestToolCallTimeout(t *testing.T) {
|
|||||||
Method: methodToolsCall,
|
Method: methodToolsCall,
|
||||||
Params: paramBytes,
|
Params: paramBytes,
|
||||||
}
|
}
|
||||||
|
jsonBody, _ := json.Marshal(req)
|
||||||
|
|
||||||
// Process the tool call
|
// Create HTTP request
|
||||||
mock.server.processToolCall(client, req)
|
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody))
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Process through handleRequest
|
||||||
|
go mock.server.handleRequest(w, r)
|
||||||
|
|
||||||
// Check the response
|
// Check the response
|
||||||
select {
|
select {
|
||||||
case response := <-client.channel:
|
case response := <-client.channel:
|
||||||
assert.Contains(t, response, "event: message", "Response should have message event")
|
assert.Contains(t, response, "event: message", "Response should have message event")
|
||||||
assert.Contains(t, response, `-32001`, "Response should contain a timeout error code")
|
assert.Contains(t, response, `-32001`, "Response should contain a timeout error code")
|
||||||
case <-time.After(150 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for tool call response")
|
t.Fatal("Timed out waiting for tool call response")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1274,7 +1217,7 @@ func TestInitializeAndNotifications(t *testing.T) {
|
|||||||
Params: json.RawMessage(`{}`),
|
Params: json.RawMessage(`{}`),
|
||||||
}
|
}
|
||||||
|
|
||||||
mock.server.processInitialize(client, initReq)
|
mock.server.processInitialize(context.Background(), client, initReq)
|
||||||
|
|
||||||
// Check that client is initialized after initialize request
|
// Check that client is initialized after initialize request
|
||||||
assert.True(t, client.initialized, "Client should be marked as initialized after initialize request")
|
assert.True(t, client.initialized, "Client should be marked as initialized after initialize request")
|
||||||
@@ -1418,7 +1361,7 @@ func TestNotificationCancelled_badParams(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
buf := logtest.NewCollector(t)
|
buf := logtest.NewCollector(t)
|
||||||
mock.server.processNotificationCancelled(client, cancelReq)
|
mock.server.processNotificationCancelled(context.Background(), client, cancelReq)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-client.channel:
|
case <-client.channel:
|
||||||
@@ -1593,7 +1536,7 @@ func TestGetPrompt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the request
|
// Process the request
|
||||||
mock.server.processGetPrompt(client, promptReq)
|
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
select {
|
select {
|
||||||
@@ -1622,7 +1565,7 @@ func TestGetPrompt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the request
|
// Process the request
|
||||||
mock.server.processGetPrompt(client, promptReq)
|
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||||
|
|
||||||
// Check response
|
// Check response
|
||||||
select {
|
select {
|
||||||
@@ -1636,6 +1579,44 @@ func TestGetPrompt(t *testing.T) {
|
|||||||
t.Fatal("Timed out waiting for prompt response")
|
t.Fatal("Timed out waiting for prompt response")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("test prompt with nil params", func(t *testing.T) {
|
||||||
|
mock := newMockMcpServer(t)
|
||||||
|
defer mock.shutdown()
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := addTestClient(mock.server, "test-client", true)
|
||||||
|
|
||||||
|
// Register a test prompt
|
||||||
|
testPrompt := Prompt{
|
||||||
|
Name: "test.prompt",
|
||||||
|
Description: "A test prompt",
|
||||||
|
}
|
||||||
|
mock.server.RegisterPrompt(testPrompt)
|
||||||
|
|
||||||
|
// Create a get prompt request
|
||||||
|
paramBytes, _ := json.Marshal(map[string]any{
|
||||||
|
"name": "test.prompt",
|
||||||
|
})
|
||||||
|
promptReq := Request{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: 1,
|
||||||
|
Method: "prompts/get",
|
||||||
|
Params: paramBytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
_, err := parseEvent(response)
|
||||||
|
assert.NoError(t, err, "Should be able to parse event")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for prompt response")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestBroadcast tests the broadcast functionality
|
// TestBroadcast tests the broadcast functionality
|
||||||
@@ -1903,34 +1884,79 @@ func TestNotificationInitialized(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSendBadResponse(t *testing.T) {
|
func TestSendResponse(t *testing.T) {
|
||||||
mock := newMockMcpServer(t)
|
t.Run("bad response", func(t *testing.T) {
|
||||||
defer mock.shutdown()
|
mock := newMockMcpServer(t)
|
||||||
|
defer mock.shutdown()
|
||||||
|
|
||||||
// Create a test client
|
// Create a test client
|
||||||
client := addTestClient(mock.server, "test-client", true)
|
client := addTestClient(mock.server, "test-client", true)
|
||||||
|
|
||||||
// Create a response
|
// Create a response
|
||||||
response := Response{
|
response := Response{
|
||||||
JsonRpc: jsonRpcVersion,
|
JsonRpc: jsonRpcVersion,
|
||||||
ID: 1,
|
ID: 1,
|
||||||
Result: make(chan int),
|
Result: make(chan int),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send the response
|
// Send the response
|
||||||
mock.server.sendResponse(client, 1, response)
|
mock.server.sendResponse(context.Background(), client, 1, response)
|
||||||
|
|
||||||
// Check the response in the client's channel
|
// Check the response in the client's channel
|
||||||
select {
|
select {
|
||||||
case res := <-client.channel:
|
case res := <-client.channel:
|
||||||
evt, err := parseEvent(res)
|
evt, err := parseEvent(res)
|
||||||
require.NoError(t, err, "Should parse event without error")
|
require.NoError(t, err, "Should parse event without error")
|
||||||
errMsg, ok := evt.Data["error"].(map[string]any)
|
errMsg, ok := evt.Data["error"].(map[string]any)
|
||||||
require.True(t, ok, "Should have error in response")
|
require.True(t, ok, "Should have error in response")
|
||||||
assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match")
|
assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match")
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for response")
|
t.Fatal("Timed out waiting for response")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("channel full", func(t *testing.T) {
|
||||||
|
mock := newMockMcpServer(t)
|
||||||
|
defer mock.shutdown()
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := addTestClient(mock.server, "test-client", true)
|
||||||
|
for i := 0; i < eventChanSize; i++ {
|
||||||
|
client.channel <- "test"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a response
|
||||||
|
response := Response{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: 1,
|
||||||
|
Result: "foo",
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := logtest.NewCollector(t)
|
||||||
|
// Send the response
|
||||||
|
mock.server.sendResponse(context.Background(), client, 1, response)
|
||||||
|
// Check the response in the client's channel
|
||||||
|
assert.Contains(t, buf.String(), "channel is full")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendErrorResponse(t *testing.T) {
|
||||||
|
t.Run("channel full", func(t *testing.T) {
|
||||||
|
mock := newMockMcpServer(t)
|
||||||
|
defer mock.shutdown()
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := addTestClient(mock.server, "test-client", true)
|
||||||
|
for i := 0; i < eventChanSize; i++ {
|
||||||
|
client.channel <- "test"
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := logtest.NewCollector(t)
|
||||||
|
// Send the response
|
||||||
|
mock.server.sendErrorResponse(context.Background(), client, 1, "foo", errCodeInternalError)
|
||||||
|
// Check the response in the client's channel
|
||||||
|
assert.Contains(t, buf.String(), "channel is full")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestMethodToolsCall tests the handling of tools/call method through handleRequest
|
// TestMethodToolsCall tests the handling of tools/call method through handleRequest
|
||||||
@@ -2028,8 +2054,7 @@ func TestMethodToolsCall(t *testing.T) {
|
|||||||
if len(content) > 0 {
|
if len(content) > 0 {
|
||||||
firstItem, ok := content[0].(map[string]any)
|
firstItem, ok := content[0].(map[string]any)
|
||||||
if ok {
|
if ok {
|
||||||
assert.Equal(t, "text", firstItem["type"], "Content type should be text")
|
assert.Contains(t, firstItem[ContentTypeText], "Processed: test-input", "Content should include processed input")
|
||||||
assert.Contains(t, firstItem["text"], "Processed: test-input", "Content should include processed input")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
@@ -2145,7 +2170,6 @@ func TestMethodPromptsGet(t *testing.T) {
|
|||||||
{
|
{
|
||||||
Name: "topic",
|
Name: "topic",
|
||||||
Description: "Topic to discuss",
|
Description: "Topic to discuss",
|
||||||
Default: "artificial intelligence",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Content: "Hello {{name}}! Let's talk about {{topic}}.",
|
Content: "Hello {{name}}! Let's talk about {{topic}}.",
|
||||||
@@ -2227,14 +2251,12 @@ func TestMethodPromptsGet(t *testing.T) {
|
|||||||
if len(messages) > 0 {
|
if len(messages) > 0 {
|
||||||
message, ok := messages[0].(map[string]any)
|
message, ok := messages[0].(map[string]any)
|
||||||
require.True(t, ok, "Message should be an object")
|
require.True(t, ok, "Message should be an object")
|
||||||
assert.Equal(t, "user", message["role"], "Role should be 'user'")
|
assert.Equal(t, string(RoleUser), message["role"], "Role should be 'user'")
|
||||||
|
|
||||||
content, ok := message["content"].(map[string]any)
|
content, ok := message["content"].(map[string]any)
|
||||||
require.True(t, ok, "Should have content object")
|
require.True(t, ok, "Should have content object")
|
||||||
assert.Equal(t, "text", content["type"], "Content type should be text")
|
assert.Equal(t, ContentTypeText, content["type"], "Content type should be text")
|
||||||
assert.Contains(t, content["text"], "Hello Test User", "Content should include the name argument")
|
assert.Contains(t, content[ContentTypeText], "Hello Test User", "Content should include the name argument")
|
||||||
assert.Contains(t, content["text"], "about artificial intelligence",
|
|
||||||
"Content should include the default topic argument")
|
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for prompt get response")
|
t.Fatal("Timed out waiting for prompt get response")
|
||||||
@@ -2255,27 +2277,24 @@ func TestMethodPromptsGet(t *testing.T) {
|
|||||||
{
|
{
|
||||||
Name: "question",
|
Name: "question",
|
||||||
Description: "User's question",
|
Description: "User's question",
|
||||||
Default: "How does this work?",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Handler: func(args map[string]string) ([]PromptMessage, error) {
|
Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) {
|
||||||
username := args["username"]
|
username := args["username"]
|
||||||
question := args["question"]
|
question := args["question"]
|
||||||
|
|
||||||
// Create a system message
|
// Create a system message
|
||||||
systemMessage := PromptMessage{
|
systemMessage := PromptMessage{
|
||||||
Role: "system",
|
Role: RoleAssistant,
|
||||||
Content: TextContent{
|
Content: TextContent{
|
||||||
Type: "text",
|
|
||||||
Text: "You are a helpful assistant.",
|
Text: "You are a helpful assistant.",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a user message
|
// Create a user message
|
||||||
userMessage := PromptMessage{
|
userMessage := PromptMessage{
|
||||||
Role: "user",
|
Role: RoleUser,
|
||||||
Content: TextContent{
|
Content: TextContent{
|
||||||
Type: "text",
|
|
||||||
Text: fmt.Sprintf("Hi, I'm %s and I'm wondering: %s", username, question),
|
Text: fmt.Sprintf("Hi, I'm %s and I'm wondering: %s", username, question),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -2340,20 +2359,20 @@ func TestMethodPromptsGet(t *testing.T) {
|
|||||||
|
|
||||||
// Check message content
|
// Check message content
|
||||||
if len(messages) >= 2 {
|
if len(messages) >= 2 {
|
||||||
// First message should be system
|
// First message should be assistant
|
||||||
message1, _ := messages[0].(map[string]any)
|
message1, _ := messages[0].(map[string]any)
|
||||||
assert.Equal(t, "system", message1["role"], "First role should be 'system'")
|
assert.Equal(t, string(RoleAssistant), message1["role"], "First role should be 'system'")
|
||||||
|
|
||||||
content1, _ := message1["content"].(map[string]any)
|
content1, _ := message1["content"].(map[string]any)
|
||||||
assert.Contains(t, content1["text"], "helpful assistant", "System message should be correct")
|
assert.Contains(t, content1[ContentTypeText], "helpful assistant", "System message should be correct")
|
||||||
|
|
||||||
// Second message should be user
|
// Second message should be user
|
||||||
message2, _ := messages[1].(map[string]any)
|
message2, _ := messages[1].(map[string]any)
|
||||||
assert.Equal(t, "user", message2["role"], "Second role should be 'user'")
|
assert.Equal(t, string(RoleUser), message2["role"], "Second role should be 'user'")
|
||||||
|
|
||||||
content2, _ := message2["content"].(map[string]any)
|
content2, _ := message2["content"].(map[string]any)
|
||||||
assert.Contains(t, content2["text"], "Dynamic User", "User message should contain username")
|
assert.Contains(t, content2[ContentTypeText], "Dynamic User", "User message should contain username")
|
||||||
assert.Contains(t, content2["text"], "How to test this?", "User message should contain question")
|
assert.Contains(t, content2[ContentTypeText], "How to test this?", "User message should contain question")
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for prompt get response")
|
t.Fatal("Timed out waiting for prompt get response")
|
||||||
@@ -2459,7 +2478,7 @@ func TestMethodPromptsGet(t *testing.T) {
|
|||||||
Name: "error-handler-prompt",
|
Name: "error-handler-prompt",
|
||||||
Description: "A prompt with a handler that returns an error",
|
Description: "A prompt with a handler that returns an error",
|
||||||
Arguments: []PromptArgument{},
|
Arguments: []PromptArgument{},
|
||||||
Handler: func(args map[string]string) ([]PromptMessage, error) {
|
Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) {
|
||||||
return nil, fmt.Errorf("test handler error")
|
return nil, fmt.Errorf("test handler error")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -2583,7 +2602,7 @@ func TestMethodResourcesList(t *testing.T) {
|
|||||||
URI: "file:///test/resource.txt",
|
URI: "file:///test/resource.txt",
|
||||||
Description: "A test resource with handler",
|
Description: "A test resource with handler",
|
||||||
MimeType: "text/plain",
|
MimeType: "text/plain",
|
||||||
Handler: func() (ResourceContent, error) {
|
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||||
return ResourceContent{
|
return ResourceContent{
|
||||||
URI: "file:///test/resource.txt",
|
URI: "file:///test/resource.txt",
|
||||||
MimeType: "text/plain",
|
MimeType: "text/plain",
|
||||||
@@ -2654,7 +2673,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
URI: "file:///test/resource.txt",
|
URI: "file:///test/resource.txt",
|
||||||
Description: "A test resource with handler",
|
Description: "A test resource with handler",
|
||||||
MimeType: "text/plain",
|
MimeType: "text/plain",
|
||||||
Handler: func() (ResourceContent, error) {
|
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||||
return ResourceContent{
|
return ResourceContent{
|
||||||
URI: "file:///test/resource.txt",
|
URI: "file:///test/resource.txt",
|
||||||
MimeType: "text/plain",
|
MimeType: "text/plain",
|
||||||
@@ -2729,7 +2748,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
require.True(t, ok, "Content should be an object")
|
require.True(t, ok, "Content should be an object")
|
||||||
assert.Equal(t, "file:///test/resource.txt", content["uri"], "URI should match")
|
assert.Equal(t, "file:///test/resource.txt", content["uri"], "URI should match")
|
||||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
||||||
assert.Equal(t, "This is test resource content", content["text"], "Text content should match")
|
assert.Equal(t, "This is test resource content", content[ContentTypeText], "Text content should match")
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for resource read response")
|
t.Fatal("Timed out waiting for resource read response")
|
||||||
@@ -2799,7 +2818,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
require.True(t, ok, "Content should be an object")
|
require.True(t, ok, "Content should be an object")
|
||||||
assert.Equal(t, "file:///test/no-handler.txt", content["uri"], "URI should match")
|
assert.Equal(t, "file:///test/no-handler.txt", content["uri"], "URI should match")
|
||||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
||||||
_, ok = content["text"]
|
_, ok = content[ContentTypeText]
|
||||||
assert.False(t, ok, "Text content should be empty string")
|
assert.False(t, ok, "Text content should be empty string")
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
@@ -2880,7 +2899,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
client := addTestClient(mock.server, "test-client-resources", true)
|
client := addTestClient(mock.server, "test-client-resources", true)
|
||||||
|
|
||||||
// Process through handleRequest
|
// Process through handleRequest
|
||||||
mock.server.processResourcesRead(client, req)
|
mock.server.processResourcesRead(context.Background(), client, req)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case response := <-client.channel:
|
case response := <-client.channel:
|
||||||
@@ -2898,7 +2917,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
URI: "file:///test/error.txt",
|
URI: "file:///test/error.txt",
|
||||||
Description: "A test resource with handler that returns error",
|
Description: "A test resource with handler that returns error",
|
||||||
MimeType: "text/plain",
|
MimeType: "text/plain",
|
||||||
Handler: func() (ResourceContent, error) {
|
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||||
return ResourceContent{}, fmt.Errorf("test handler error")
|
return ResourceContent{}, fmt.Errorf("test handler error")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -2946,7 +2965,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
URI: "file:///test/missing-fields.txt",
|
URI: "file:///test/missing-fields.txt",
|
||||||
Description: "A test resource with handler that returns content missing fields",
|
Description: "A test resource with handler that returns content missing fields",
|
||||||
MimeType: "text/plain",
|
MimeType: "text/plain",
|
||||||
Handler: func() (ResourceContent, error) {
|
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||||
// Return ResourceContent without URI and MimeType
|
// Return ResourceContent without URI and MimeType
|
||||||
return ResourceContent{
|
return ResourceContent{
|
||||||
Text: "Content with missing fields",
|
Text: "Content with missing fields",
|
||||||
@@ -3006,7 +3025,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
|||||||
require.True(t, ok, "Content should be an object")
|
require.True(t, ok, "Content should be an object")
|
||||||
assert.Equal(t, "file:///test/missing-fields.txt", content["uri"], "URI should be filled from request")
|
assert.Equal(t, "file:///test/missing-fields.txt", content["uri"], "URI should be filled from request")
|
||||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should be filled from resource")
|
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should be filled from resource")
|
||||||
assert.Equal(t, "Content with missing fields", content["text"], "Text content should match")
|
assert.Equal(t, "Content with missing fields", content[ContentTypeText], "Text content should match")
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
t.Fatal("Timed out waiting for resource read response")
|
t.Fatal("Timed out waiting for resource read response")
|
||||||
@@ -3159,7 +3178,7 @@ func TestMethodResourcesSubscribe(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
client := addTestClient(mock.server, "test-client-sub-not-found", true)
|
client := addTestClient(mock.server, "test-client-sub-not-found", true)
|
||||||
mock.server.processResourceSubscribe(client, req)
|
mock.server.processResourceSubscribe(context.Background(), client, req)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case response := <-client.channel:
|
case response := <-client.channel:
|
||||||
@@ -3268,7 +3287,7 @@ func TestToolCallUnmarshalError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the tool call directly
|
// Process the tool call directly
|
||||||
mock.server.processToolCall(client, req)
|
mock.server.processToolCall(context.Background(), client, req)
|
||||||
|
|
||||||
// Check for error response about invalid JSON
|
// Check for error response about invalid JSON
|
||||||
select {
|
select {
|
||||||
@@ -3316,7 +3335,7 @@ func TestToolCallWithInvalidParams(t *testing.T) {
|
|||||||
jsonBody, _ := json.Marshal(req)
|
jsonBody, _ := json.Marshal(req)
|
||||||
|
|
||||||
// Create HTTP request
|
// Create HTTP request
|
||||||
r := httptest.NewRequest("POST", "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody))
|
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// Process through handleRequest
|
// Process through handleRequest
|
||||||
|
|||||||
42
mcp/types.go
42
mcp/types.go
@@ -1,6 +1,7 @@
|
|||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -45,19 +46,18 @@ type ListToolsResult struct {
|
|||||||
|
|
||||||
// Message Content Types
|
// Message Content Types
|
||||||
|
|
||||||
// roleType represents the sender or recipient of messages in a conversation
|
// RoleType represents the sender or recipient of messages in a conversation
|
||||||
type roleType string
|
type RoleType string
|
||||||
|
|
||||||
// PromptArgument defines a single argument that can be passed to a prompt
|
// PromptArgument defines a single argument that can be passed to a prompt
|
||||||
type PromptArgument struct {
|
type PromptArgument struct {
|
||||||
Name string `json:"name"` // Argument name
|
Name string `json:"name"` // Argument name
|
||||||
Description string `json:"description,omitempty"` // Human-readable description
|
Description string `json:"description,omitempty"` // Human-readable description
|
||||||
Required bool `json:"required,omitempty"` // Whether this argument is required
|
Required bool `json:"required,omitempty"` // Whether this argument is required
|
||||||
Default string `json:"default,omitempty"` // Default value if not provided
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PromptHandler is a function that dynamically generates prompt content
|
// PromptHandler is a function that dynamically generates prompt content
|
||||||
type PromptHandler func(args map[string]string) ([]PromptMessage, error)
|
type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error)
|
||||||
|
|
||||||
// Prompt represents an MCP Prompt definition
|
// Prompt represents an MCP Prompt definition
|
||||||
type Prompt struct {
|
type Prompt struct {
|
||||||
@@ -70,31 +70,43 @@ type Prompt struct {
|
|||||||
|
|
||||||
// PromptMessage represents a message in a conversation
|
// PromptMessage represents a message in a conversation
|
||||||
type PromptMessage struct {
|
type PromptMessage struct {
|
||||||
Role roleType `json:"role"` // Message sender role
|
Role RoleType `json:"role"` // Message sender role
|
||||||
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TextContent represents text content in a message
|
// TextContent represents text content in a message
|
||||||
type TextContent struct {
|
type TextContent struct {
|
||||||
Type string `json:"type"` // Always "text"
|
|
||||||
Text string `json:"text"` // The text content
|
Text string `json:"text"` // The text content
|
||||||
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type typedTextContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
TextContent
|
||||||
|
}
|
||||||
|
|
||||||
// ImageContent represents image data in a message
|
// ImageContent represents image data in a message
|
||||||
type ImageContent struct {
|
type ImageContent struct {
|
||||||
Type string `json:"type"` // Always "image"
|
|
||||||
Data string `json:"data"` // Base64-encoded image data
|
Data string `json:"data"` // Base64-encoded image data
|
||||||
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type typedImageContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
ImageContent
|
||||||
|
}
|
||||||
|
|
||||||
// AudioContent represents audio data in a message
|
// AudioContent represents audio data in a message
|
||||||
type AudioContent struct {
|
type AudioContent struct {
|
||||||
Type string `json:"type"` // Always "audio"
|
|
||||||
Data string `json:"data"` // Base64-encoded audio data
|
Data string `json:"data"` // Base64-encoded audio data
|
||||||
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type typedAudioContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
AudioContent
|
||||||
|
}
|
||||||
|
|
||||||
// FileContent represents file content
|
// FileContent represents file content
|
||||||
type FileContent struct {
|
type FileContent struct {
|
||||||
URI string `json:"uri"` // URI identifying the file
|
URI string `json:"uri"` // URI identifying the file
|
||||||
@@ -115,16 +127,14 @@ type EmbeddedResource struct {
|
|||||||
|
|
||||||
// Annotations provides additional metadata for content
|
// Annotations provides additional metadata for content
|
||||||
type Annotations struct {
|
type Annotations struct {
|
||||||
Audience []roleType `json:"audience,omitempty"` // Who should see this content
|
Audience []RoleType `json:"audience,omitempty"` // Who should see this content
|
||||||
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tool-related Types
|
// Tool-related Types
|
||||||
|
|
||||||
// Tool Definition Types
|
|
||||||
|
|
||||||
// ToolHandler is a function that handles tool calls
|
// ToolHandler is a function that handles tool calls
|
||||||
type ToolHandler func(params map[string]any) (any, error)
|
type ToolHandler func(ctx context.Context, params map[string]any) (any, error)
|
||||||
|
|
||||||
// Tool represents a Model Context Protocol Tool definition
|
// Tool represents a Model Context Protocol Tool definition
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
@@ -136,7 +146,7 @@ type Tool struct {
|
|||||||
|
|
||||||
// InputSchema represents tool's input schema in JSON Schema format
|
// InputSchema represents tool's input schema in JSON Schema format
|
||||||
type InputSchema struct {
|
type InputSchema struct {
|
||||||
Type string `json:"type"` // Always "object" for tool inputs
|
Type string `json:"type"`
|
||||||
Properties map[string]any `json:"properties"` // Property definitions
|
Properties map[string]any `json:"properties"` // Property definitions
|
||||||
Required []string `json:"required,omitempty"` // List of required properties
|
Required []string `json:"required,omitempty"` // List of required properties
|
||||||
}
|
}
|
||||||
@@ -144,8 +154,8 @@ type InputSchema struct {
|
|||||||
// CallToolResult represents a tool call result that conforms to the MCP schema
|
// CallToolResult represents a tool call result that conforms to the MCP schema
|
||||||
type CallToolResult struct {
|
type CallToolResult struct {
|
||||||
Result
|
Result
|
||||||
Content []interface{} `json:"content"` // Content items (text, images, etc.)
|
Content []any `json:"content"` // Content items (text, images, etc.)
|
||||||
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resource represents a Model Context Protocol Resource definition
|
// Resource represents a Model Context Protocol Resource definition
|
||||||
@@ -158,7 +168,7 @@ type Resource struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ResourceHandler is a function that handles resource read requests
|
// ResourceHandler is a function that handles resource read requests
|
||||||
type ResourceHandler func() (ResourceContent, error)
|
type ResourceHandler func(ctx context.Context) (ResourceContent, error)
|
||||||
|
|
||||||
// ResourceContent represents the content of a resource
|
// ResourceContent represents the content of a resource
|
||||||
type ResourceContent struct {
|
type ResourceContent struct {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -79,7 +80,7 @@ func TestToolStructs(t *testing.T) {
|
|||||||
},
|
},
|
||||||
Required: []string{"input"},
|
Required: []string{"input"},
|
||||||
},
|
},
|
||||||
Handler: func(params map[string]any) (any, error) {
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
return "result", nil
|
return "result", nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -145,44 +146,38 @@ func TestResourceStructs(t *testing.T) {
|
|||||||
func TestContentTypes(t *testing.T) {
|
func TestContentTypes(t *testing.T) {
|
||||||
// Test TextContent
|
// Test TextContent
|
||||||
textContent := TextContent{
|
textContent := TextContent{
|
||||||
Type: "text",
|
|
||||||
Text: "Sample text",
|
Text: "Sample text",
|
||||||
Annotations: &Annotations{
|
Annotations: &Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
Priority: ptr(1.0),
|
Priority: ptr(1.0),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(textContent)
|
data, err := json.Marshal(textContent)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Contains(t, string(data), `"type":"text"`)
|
|
||||||
assert.Contains(t, string(data), `"text":"Sample text"`)
|
assert.Contains(t, string(data), `"text":"Sample text"`)
|
||||||
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
||||||
assert.Contains(t, string(data), `"priority":1`)
|
assert.Contains(t, string(data), `"priority":1`)
|
||||||
|
|
||||||
// Test ImageContent
|
// Test ImageContent
|
||||||
imageContent := ImageContent{
|
imageContent := ImageContent{
|
||||||
Type: "image",
|
|
||||||
Data: "base64data",
|
Data: "base64data",
|
||||||
MimeType: "image/png",
|
MimeType: "image/png",
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err = json.Marshal(imageContent)
|
data, err = json.Marshal(imageContent)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Contains(t, string(data), `"type":"image"`)
|
|
||||||
assert.Contains(t, string(data), `"data":"base64data"`)
|
assert.Contains(t, string(data), `"data":"base64data"`)
|
||||||
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
||||||
|
|
||||||
// Test AudioContent
|
// Test AudioContent
|
||||||
audioContent := AudioContent{
|
audioContent := AudioContent{
|
||||||
Type: "audio",
|
|
||||||
Data: "base64audio",
|
Data: "base64audio",
|
||||||
MimeType: "audio/mp3",
|
MimeType: "audio/mp3",
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err = json.Marshal(audioContent)
|
data, err = json.Marshal(audioContent)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Contains(t, string(data), `"type":"audio"`)
|
|
||||||
assert.Contains(t, string(data), `"data":"base64audio"`)
|
assert.Contains(t, string(data), `"data":"base64audio"`)
|
||||||
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
||||||
}
|
}
|
||||||
@@ -197,7 +192,6 @@ func TestCallToolResult(t *testing.T) {
|
|||||||
},
|
},
|
||||||
Content: []interface{}{
|
Content: []interface{}{
|
||||||
TextContent{
|
TextContent{
|
||||||
Type: "text",
|
|
||||||
Text: "Sample result",
|
Text: "Sample result",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -207,6 +201,6 @@ func TestCallToolResult(t *testing.T) {
|
|||||||
data, err := json.Marshal(result)
|
data, err := json.Marshal(result)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
||||||
assert.Contains(t, string(data), `"content":[{"type":"text","text":"Sample result"}]`)
|
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
||||||
assert.NotContains(t, string(data), `"isError":`)
|
assert.NotContains(t, string(data), `"isError":`)
|
||||||
}
|
}
|
||||||
|
|||||||
104
mcp/util.go
104
mcp/util.go
@@ -1,15 +1,107 @@
|
|||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import "fmt"
|
||||||
"fmt"
|
|
||||||
)
|
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
||||||
|
func formatSSEMessage(event string, data []byte) string {
|
||||||
|
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
// ptr is a helper function to get a pointer to a value
|
// ptr is a helper function to get a pointer to a value
|
||||||
func ptr[T any](v T) *T {
|
func ptr[T any](v T) *T {
|
||||||
return &v
|
return &v
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
func toTypedContents(contents []any) []any {
|
||||||
func formatSSEMessage(event string, data []byte) string {
|
typedContents := make([]any, len(contents))
|
||||||
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
|
||||||
|
for i, content := range contents {
|
||||||
|
switch v := content.(type) {
|
||||||
|
case TextContent:
|
||||||
|
typedContents[i] = typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: v,
|
||||||
|
}
|
||||||
|
case ImageContent:
|
||||||
|
typedContents[i] = typedImageContent{
|
||||||
|
Type: ContentTypeImage,
|
||||||
|
ImageContent: v,
|
||||||
|
}
|
||||||
|
case AudioContent:
|
||||||
|
typedContents[i] = typedAudioContent{
|
||||||
|
Type: ContentTypeAudio,
|
||||||
|
AudioContent: v,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
typedContents[i] = typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: TextContent{
|
||||||
|
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return typedContents
|
||||||
|
}
|
||||||
|
|
||||||
|
func toTypedPromptMessages(messages []PromptMessage) []PromptMessage {
|
||||||
|
typedMessages := make([]PromptMessage, len(messages))
|
||||||
|
|
||||||
|
for i, msg := range messages {
|
||||||
|
switch v := msg.Content.(type) {
|
||||||
|
case TextContent:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case ImageContent:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedImageContent{
|
||||||
|
Type: ContentTypeImage,
|
||||||
|
ImageContent: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case AudioContent:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedAudioContent{
|
||||||
|
Type: ContentTypeAudio,
|
||||||
|
AudioContent: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: TextContent{
|
||||||
|
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return typedMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePromptArguments checks if all required arguments are provided
|
||||||
|
// Returns a list of missing required arguments
|
||||||
|
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||||
|
var missingArgs []string
|
||||||
|
|
||||||
|
for _, arg := range prompt.Arguments {
|
||||||
|
if arg.Required {
|
||||||
|
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||||
|
missingArgs = append(missingArgs, arg.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return missingArgs
|
||||||
}
|
}
|
||||||
|
|||||||
253
mcp/util_test.go
253
mcp/util_test.go
@@ -8,29 +8,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPtr(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
v interface{}
|
|
||||||
}{
|
|
||||||
{"string", "test"},
|
|
||||||
{"int", 42},
|
|
||||||
{"bool", true},
|
|
||||||
{"float", 3.14},
|
|
||||||
{"struct", struct{ Name string }{"test"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := ptr(tt.v)
|
|
||||||
assert.NotNil(t, got, "ptr() should not return nil")
|
|
||||||
assert.Equal(t, tt.v, *got, "dereferenced pointer should equal input value")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Event struct {
|
type Event struct {
|
||||||
Type string
|
Type string
|
||||||
Data map[string]any
|
Data map[string]any
|
||||||
@@ -61,3 +41,234 @@ func parseEvent(input string) (*Event, error) {
|
|||||||
|
|
||||||
return &evt, nil
|
return &evt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestToTypedPromptMessages tests the toTypedPromptMessages function
|
||||||
|
func TestToTypedPromptMessages(t *testing.T) {
|
||||||
|
// Test with multiple message types in one test
|
||||||
|
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||||
|
// Create test data with different content types
|
||||||
|
messages := []PromptMessage{
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: TextContent{
|
||||||
|
Text: "Hello, this is a text message",
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
Priority: ptr(0.8),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: RoleAssistant,
|
||||||
|
Content: ImageContent{
|
||||||
|
Data: "base64ImageData",
|
||||||
|
MimeType: "image/jpeg",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: AudioContent{
|
||||||
|
Data: "base64AudioData",
|
||||||
|
MimeType: "audio/mp3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: "This is a simple string that should be handled as unknown type",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the function
|
||||||
|
result := toTypedPromptMessages(messages)
|
||||||
|
|
||||||
|
// Validate results
|
||||||
|
require.Len(t, result, 4, "Should return the same number of messages")
|
||||||
|
|
||||||
|
// Validate first message (TextContent)
|
||||||
|
msg := result[0]
|
||||||
|
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Type assertion using reflection since Content is an interface
|
||||||
|
typed, ok := msg.Content.(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||||
|
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||||
|
assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||||
|
|
||||||
|
// Validate second message (ImageContent)
|
||||||
|
msg = result[1]
|
||||||
|
assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Type assertion for image content
|
||||||
|
typedImg, ok := msg.Content.(typedImageContent)
|
||||||
|
require.True(t, ok, "Should be typedImageContent")
|
||||||
|
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||||
|
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||||
|
assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate third message (AudioContent)
|
||||||
|
msg = result[2]
|
||||||
|
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Type assertion for audio content
|
||||||
|
typedAudio, ok := msg.Content.(typedAudioContent)
|
||||||
|
require.True(t, ok, "Should be typedAudioContent")
|
||||||
|
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||||
|
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||||
|
assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate fourth message (unknown type converted to TextContent)
|
||||||
|
msg = result[3]
|
||||||
|
assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Should be converted to a typedTextContent with error message
|
||||||
|
typedUnknown, ok := msg.Content.(typedTextContent)
|
||||||
|
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test empty input
|
||||||
|
t.Run("EmptyInput", func(t *testing.T) {
|
||||||
|
messages := []PromptMessage{}
|
||||||
|
result := toTypedPromptMessages(messages)
|
||||||
|
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with nil annotations
|
||||||
|
t.Run("NilAnnotations", func(t *testing.T) {
|
||||||
|
messages := []PromptMessage{
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: TextContent{
|
||||||
|
Text: "Text with nil annotations",
|
||||||
|
Annotations: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := toTypedPromptMessages(messages)
|
||||||
|
require.Len(t, result, 1, "Should return one message")
|
||||||
|
|
||||||
|
typed, ok := result[0].Content.(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||||
|
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToTypedContents tests the toTypedContents function
|
||||||
|
func TestToTypedContents(t *testing.T) {
|
||||||
|
// Test with multiple content types in one test
|
||||||
|
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||||
|
// Create test data with different content types
|
||||||
|
contents := []any{
|
||||||
|
TextContent{
|
||||||
|
Text: "Hello, this is a text content",
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
Priority: ptr(0.7),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ImageContent{
|
||||||
|
Data: "base64ImageData",
|
||||||
|
MimeType: "image/png",
|
||||||
|
},
|
||||||
|
AudioContent{
|
||||||
|
Data: "base64AudioData",
|
||||||
|
MimeType: "audio/wav",
|
||||||
|
},
|
||||||
|
"This is a simple string that should be handled as unknown type",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the function
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
|
||||||
|
// Validate results
|
||||||
|
require.Len(t, result, 4, "Should return the same number of contents")
|
||||||
|
|
||||||
|
// Validate first content (TextContent)
|
||||||
|
typed, ok := result[0].(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||||
|
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||||
|
assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||||
|
|
||||||
|
// Validate second content (ImageContent)
|
||||||
|
typedImg, ok := result[1].(typedImageContent)
|
||||||
|
require.True(t, ok, "Should be typedImageContent")
|
||||||
|
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||||
|
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||||
|
assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate third content (AudioContent)
|
||||||
|
typedAudio, ok := result[2].(typedAudioContent)
|
||||||
|
require.True(t, ok, "Should be typedAudioContent")
|
||||||
|
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||||
|
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||||
|
assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate fourth content (unknown type converted to TextContent)
|
||||||
|
typedUnknown, ok := result[3].(typedTextContent)
|
||||||
|
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test empty input
|
||||||
|
t.Run("EmptyInput", func(t *testing.T) {
|
||||||
|
contents := []any{}
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with nil annotations
|
||||||
|
t.Run("NilAnnotations", func(t *testing.T) {
|
||||||
|
contents := []any{
|
||||||
|
TextContent{
|
||||||
|
Text: "Text with nil annotations",
|
||||||
|
Annotations: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
require.Len(t, result, 1, "Should return one content")
|
||||||
|
|
||||||
|
typed, ok := result[0].(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||||
|
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with custom struct (should be handled as unknown type)
|
||||||
|
t.Run("CustomStruct", func(t *testing.T) {
|
||||||
|
type CustomContent struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := []any{
|
||||||
|
CustomContent{
|
||||||
|
Data: "custom data",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
require.Len(t, result, 1, "Should return one content")
|
||||||
|
|
||||||
|
typed, ok := result[0].(typedTextContent)
|
||||||
|
require.True(t, ok, "Custom struct should be converted to typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||||
|
assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
28
mcp/vars.go
28
mcp/vars.go
@@ -13,6 +13,9 @@ const (
|
|||||||
|
|
||||||
// Session identifier key used in request URLs
|
// Session identifier key used in request URLs
|
||||||
sessionIdKey = "session_id"
|
sessionIdKey = "session_id"
|
||||||
|
|
||||||
|
// progressTokenKey is used to track progress of long-running tasks
|
||||||
|
progressTokenKey = "progressToken"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server-Sent Events (SSE) event types
|
// Server-Sent Events (SSE) event types
|
||||||
@@ -26,11 +29,20 @@ const (
|
|||||||
|
|
||||||
// Content type identifiers
|
// Content type identifiers
|
||||||
const (
|
const (
|
||||||
// Text content type
|
// ContentTypeObject is object content type
|
||||||
contentTypeText = "text"
|
ContentTypeObject = "object"
|
||||||
|
|
||||||
// Image content type
|
// ContentTypeText is text content type
|
||||||
contentTypeImage = "image"
|
ContentTypeText = "text"
|
||||||
|
|
||||||
|
// ContentTypeImage is image content type
|
||||||
|
ContentTypeImage = "image"
|
||||||
|
|
||||||
|
// ContentTypeAudio is audio content type
|
||||||
|
ContentTypeAudio = "audio"
|
||||||
|
|
||||||
|
// ContentTypeResource is resource content type
|
||||||
|
ContentTypeResource = "resource"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Collection keys for broadcast events
|
// Collection keys for broadcast events
|
||||||
@@ -72,11 +84,11 @@ const (
|
|||||||
|
|
||||||
// User and assistant role definitions
|
// User and assistant role definitions
|
||||||
const (
|
const (
|
||||||
// The "user" role - the entity asking questions
|
// RoleUser is the "user" role - the entity asking questions
|
||||||
roleUser roleType = "user"
|
RoleUser RoleType = "user"
|
||||||
|
|
||||||
// The "assistant" role - the entity providing responses
|
// RoleAssistant is the "assistant" role - the entity providing responses
|
||||||
roleAssistant roleType = "assistant"
|
RoleAssistant RoleType = "assistant"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Method names as defined in the MCP specification
|
// Method names as defined in the MCP specification
|
||||||
|
|||||||
@@ -146,13 +146,9 @@ func TestCollectionKeys(t *testing.T) {
|
|||||||
|
|
||||||
// TestRoleTypes checks that role types are used correctly
|
// TestRoleTypes checks that role types are used correctly
|
||||||
func TestRoleTypes(t *testing.T) {
|
func TestRoleTypes(t *testing.T) {
|
||||||
// Verify role type constants
|
|
||||||
assert.Equal(t, "user", string(roleUser), "User role should be 'user'")
|
|
||||||
assert.Equal(t, "assistant", string(roleAssistant), "Assistant role should be 'assistant'")
|
|
||||||
|
|
||||||
// Test in annotations
|
// Test in annotations
|
||||||
annotations := Annotations{
|
annotations := Annotations{
|
||||||
Audience: []roleType{roleUser, roleAssistant},
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
}
|
}
|
||||||
data, err := json.Marshal(annotations)
|
data, err := json.Marshal(annotations)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|||||||
@@ -302,6 +302,8 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
|
|||||||
>103. 爱芯元智半导体股份有限公司
|
>103. 爱芯元智半导体股份有限公司
|
||||||
>104. 杭州升恒科技有限公司
|
>104. 杭州升恒科技有限公司
|
||||||
>105. 昆仑万维科技股份有限公司
|
>105. 昆仑万维科技股份有限公司
|
||||||
|
>106. 无锡盛算信息技术有限公司
|
||||||
|
>107. 深圳市聚货通信息科技有限公司
|
||||||
|
|
||||||
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
|
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/internal/cobrax"
|
"github.com/zeromicro/go-zero/tools/goctl/internal/cobrax"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/pkg/env"
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/plugin"
|
"github.com/zeromicro/go-zero/tools/goctl/plugin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -110,8 +109,5 @@ func init() {
|
|||||||
validateCmdFlags.StringVar(&validate.VarStringAPI, "api")
|
validateCmdFlags.StringVar(&validate.VarStringAPI, "api")
|
||||||
|
|
||||||
// Add sub-commands
|
// Add sub-commands
|
||||||
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd)
|
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd, swaggerCmd)
|
||||||
if env.UseExperimental() {
|
|
||||||
Cmd.AddCommand(swaggerCmd)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,8 +42,19 @@ var (
|
|||||||
func GoFormatApi(_ *cobra.Command, _ []string) error {
|
func GoFormatApi(_ *cobra.Command, _ []string) error {
|
||||||
var be errorx.BatchError
|
var be errorx.BatchError
|
||||||
if VarBoolUseStdin {
|
if VarBoolUseStdin {
|
||||||
if err := apiFormatReader(os.Stdin, VarStringDir, VarBoolSkipCheckDeclare); err != nil {
|
if env.UseExperimental() {
|
||||||
be.Add(err)
|
data, err := io.ReadAll(os.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
be.Add(err)
|
||||||
|
} else {
|
||||||
|
if err := apiF.Source(data, os.Stdout); err != nil {
|
||||||
|
be.Add(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := apiFormatReader(os.Stdin, VarStringDir, VarBoolSkipCheckDeclare); err != nil {
|
||||||
|
be.Add(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if len(VarStringDir) == 0 {
|
if len(VarStringDir) == 0 {
|
||||||
|
|||||||
@@ -69,3 +69,16 @@ func getListFromInfoOrDefault(properties map[string]string, key string, def []st
|
|||||||
}
|
}
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getFirstUsableString(def ...string) string {
|
||||||
|
if len(def) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, val := range def {
|
||||||
|
str := util.Unquote(val)
|
||||||
|
if len(str) != 0 {
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ info (
|
|||||||
host: "example.com" // 对应 swagger 的 host,不填默认为 127.0.0.1
|
host: "example.com" // 对应 swagger 的 host,不填默认为 127.0.0.1
|
||||||
basePath: "/v1" // 对应 swagger 的 basePath,不填默认为 /
|
basePath: "/v1" // 对应 swagger 的 basePath,不填默认为 /
|
||||||
wrapCodeMsg: "true" // 是否用 code-msg 通用响应体,如果开启,则以格式 {"code":0,"msg":"OK","data":$data} 包括响应体
|
wrapCodeMsg: "true" // 是否用 code-msg 通用响应体,如果开启,则以格式 {"code":0,"msg":"OK","data":$data} 包括响应体
|
||||||
bizCodeEnumDescription: "1001-未登录<br>1002-无权限操作" // 业务错误码枚举描述,json 格式,key 为业务错误码,value 为该错误码的描述,仅当 wrapCodeMsg 为 true 时生效
|
bizCodeEnumDescription: "1001-未登录<br>1002-无权限操作" // 全局业务错误码枚举描述,json 格式,key 为业务错误码,value 为该错误码的描述,仅当 wrapCodeMsg 为 true 时生效
|
||||||
// securityDefinitionsFromJson 为自定义鉴权配置,json 内容将直接放入 swagger 的 securityDefinitions 中,
|
// securityDefinitionsFromJson 为自定义鉴权配置,json 内容将直接放入 swagger 的 securityDefinitions 中,
|
||||||
// 格式参考 https://swagger.io/specification/v2/#security-definitions-object
|
// 格式参考 https://swagger.io/specification/v2/#security-definitions-object
|
||||||
// 在 api 的 @server 中可声明 authType 来指定其路由使用的鉴权类型
|
// 在 api 的 @server 中可声明 authType 来指定其路由使用的鉴权类型
|
||||||
@@ -52,6 +52,7 @@ type (
|
|||||||
service Swagger {
|
service Swagger {
|
||||||
@doc (
|
@doc (
|
||||||
description: "query 接口"
|
description: "query 接口"
|
||||||
|
bizCodeEnumDescription: " 1003-用不存在<br>1004-非法操作" // 接口级别业务错误码枚举描述,会覆盖全局的业务错误码,json 格式,key 为业务错误码,value 为该错误码的描述,仅当 wrapCodeMsg 为 true 时生效
|
||||||
)
|
)
|
||||||
@handler query
|
@handler query
|
||||||
get /query (QueryReq) returns (QueryResp)
|
get /query (QueryReq) returns (QueryResp)
|
||||||
|
|||||||
@@ -77,10 +77,10 @@ func spec2Path(info apiSpec.Info, group apiSpec.Group, route apiSpec.Route) spec
|
|||||||
Produces: getListFromInfoOrDefault(route.AtDoc.Properties, "produces", []string{applicationJson}),
|
Produces: getListFromInfoOrDefault(route.AtDoc.Properties, "produces", []string{applicationJson}),
|
||||||
Schemes: getListFromInfoOrDefault(route.AtDoc.Properties, "schemes", []string{schemeHttps}),
|
Schemes: getListFromInfoOrDefault(route.AtDoc.Properties, "schemes", []string{schemeHttps}),
|
||||||
Tags: getListFromInfoOrDefault(group.Annotation.Properties, "tags", []string{""}),
|
Tags: getListFromInfoOrDefault(group.Annotation.Properties, "tags", []string{""}),
|
||||||
Summary: getStringFromKVOrDefault(route.AtDoc.Properties, "summary", ""),
|
Summary: getStringFromKVOrDefault(route.AtDoc.Properties, "summary", getFirstUsableString(route.AtDoc.Text, route.Handler)),
|
||||||
Deprecated: getBoolFromKVOrDefault(route.AtDoc.Properties, "deprecated", false),
|
Deprecated: getBoolFromKVOrDefault(route.AtDoc.Properties, "deprecated", false),
|
||||||
Parameters: parametersFromType(route.Method, route.RequestType),
|
Parameters: parametersFromType(route.Method, route.RequestType),
|
||||||
Responses: jsonResponseFromType(info, route.ResponseType),
|
Responses: jsonResponseFromType(info, route.AtDoc, route.ResponseType),
|
||||||
Security: security,
|
Security: security,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
func jsonResponseFromType(info apiSpec.Info, tp apiSpec.Type) *spec.Responses {
|
func jsonResponseFromType(info apiSpec.Info, atDoc apiSpec.AtDoc, tp apiSpec.Type) *spec.Responses {
|
||||||
p, _ := propertiesFromType(tp)
|
p, _ := propertiesFromType(tp)
|
||||||
props := spec.SchemaProps{
|
props := spec.SchemaProps{
|
||||||
Type: typeFromGoType(tp),
|
Type: typeFromGoType(tp),
|
||||||
@@ -19,7 +19,7 @@ func jsonResponseFromType(info apiSpec.Info, tp apiSpec.Type) *spec.Responses {
|
|||||||
Default: &spec.Response{
|
Default: &spec.Response{
|
||||||
ResponseProps: spec.ResponseProps{
|
ResponseProps: spec.ResponseProps{
|
||||||
Schema: &spec.Schema{
|
Schema: &spec.Schema{
|
||||||
SchemaProps: wrapCodeMsgProps(props, info),
|
SchemaProps: wrapCodeMsgProps(props, info, atDoc),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -256,11 +256,13 @@ func pathVariable2SwaggerVariable(path string) string {
|
|||||||
return "/" + strings.Join(resp, "/")
|
return "/" + strings.Join(resp, "/")
|
||||||
}
|
}
|
||||||
|
|
||||||
func wrapCodeMsgProps(properties spec.SchemaProps, api apiSpec.Info) spec.SchemaProps {
|
func wrapCodeMsgProps(properties spec.SchemaProps, api apiSpec.Info, atDoc apiSpec.AtDoc) spec.SchemaProps {
|
||||||
wrapCodeMsg := getBoolFromKVOrDefault(api.Properties, "wrapCodeMsg", false)
|
wrapCodeMsg := getBoolFromKVOrDefault(api.Properties, "wrapCodeMsg", false)
|
||||||
if !wrapCodeMsg {
|
if !wrapCodeMsg {
|
||||||
return properties
|
return properties
|
||||||
}
|
}
|
||||||
|
globalCodeDesc := getStringFromKVOrDefault(api.Properties, "bizCodeEnumDescription", "business code")
|
||||||
|
methodCodeDesc := getStringFromKVOrDefault(atDoc.Properties, "bizCodeEnumDescription", globalCodeDesc)
|
||||||
return spec.SchemaProps{
|
return spec.SchemaProps{
|
||||||
Type: []string{swaggerTypeObject},
|
Type: []string{swaggerTypeObject},
|
||||||
Properties: spec.SchemaProperties{
|
Properties: spec.SchemaProperties{
|
||||||
@@ -270,7 +272,7 @@ func wrapCodeMsgProps(properties spec.SchemaProps, api apiSpec.Info) spec.Schema
|
|||||||
},
|
},
|
||||||
SchemaProps: spec.SchemaProps{
|
SchemaProps: spec.SchemaProps{
|
||||||
Type: []string{swaggerTypeInteger},
|
Type: []string{swaggerTypeInteger},
|
||||||
Description: getStringFromKVOrDefault(api.Properties, "bizCodeEnumDescription", "business code"),
|
Description: methodCodeDesc,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"msg": {
|
"msg": {
|
||||||
|
|||||||
@@ -34,6 +34,10 @@ func writeProperty(writer io.Writer, member spec.Member, indent int) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if strings.Contains(name, "-") {
|
||||||
|
name = fmt.Sprintf("'%s'", name)
|
||||||
|
}
|
||||||
|
|
||||||
comment := member.GetComment()
|
comment := member.GetComment()
|
||||||
if len(comment) > 0 {
|
if len(comment) > 0 {
|
||||||
comment = strings.TrimPrefix(comment, "//")
|
comment = strings.TrimPrefix(comment, "//")
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ go 1.21
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
github.com/emicklei/proto v1.14.0
|
github.com/emicklei/proto v1.14.1
|
||||||
github.com/fatih/structtag v1.2.0
|
github.com/fatih/structtag v1.2.0
|
||||||
github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e
|
github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e
|
||||||
github.com/go-sql-driver/mysql v1.9.0
|
github.com/go-sql-driver/mysql v1.9.0
|
||||||
|
|||||||
@@ -32,8 +32,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
|||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g=
|
github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g=
|
||||||
github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
|
github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
|
||||||
github.com/emicklei/proto v1.14.0 h1:WYxC0OrBuuC+FUCTZvb8+fzEHdZMwLEF+OnVfZA3LXU=
|
github.com/emicklei/proto v1.14.1 h1:fFq+Bj70XXZWXWikcVRvYZxrMS4KIIiPAqdJ8vPrenY=
|
||||||
github.com/emicklei/proto v1.14.0/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
|
github.com/emicklei/proto v1.14.1/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
|
||||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||||
github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4=
|
github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4=
|
||||||
|
|||||||
@@ -28,7 +28,7 @@
|
|||||||
"dir": "{{.goctl.api.dir}}",
|
"dir": "{{.goctl.api.dir}}",
|
||||||
"iu": "Ignore update",
|
"iu": "Ignore update",
|
||||||
"stdin": "Use stdin to input api doc content, press \"ctrl + d\" to send EOF",
|
"stdin": "Use stdin to input api doc content, press \"ctrl + d\" to send EOF",
|
||||||
"declare": "Use to skip check api types already declare"
|
"declare": "Use to skip check api types already declare, deprecated if goctl version >= 1.8.3"
|
||||||
},
|
},
|
||||||
"go": {
|
"go": {
|
||||||
"short": "Generate go files for provided api in api file",
|
"short": "Generate go files for provided api in api file",
|
||||||
@@ -185,6 +185,7 @@
|
|||||||
},
|
},
|
||||||
"pg": {
|
"pg": {
|
||||||
"short": "Generate postgresql model",
|
"short": "Generate postgresql model",
|
||||||
|
"prefix": "The cache prefix, effective when --cache is true",
|
||||||
"datasource": {
|
"datasource": {
|
||||||
"short": "Generate model from datasource",
|
"short": "Generate model from datasource",
|
||||||
"url": "The data source of database,like \"postgres://root:password@127.0.0.1:5432/database?sslmode=disable\"",
|
"url": "The data source of database,like \"postgres://root:password@127.0.0.1:5432/database?sslmode=disable\"",
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// BuildVersion is the version of goctl.
|
// BuildVersion is the version of goctl.
|
||||||
const BuildVersion = "1.8.3-beta"
|
const BuildVersion = "1.8.3"
|
||||||
|
|
||||||
var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-bata": 2, "beta": 3, "released": 4, "": 5}
|
var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-bata": 2, "beta": 3, "released": 4, "": 5}
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ func init() {
|
|||||||
pgDatasourceCmdFlags.StringVar(&command.VarStringBranch, "branch")
|
pgDatasourceCmdFlags.StringVar(&command.VarStringBranch, "branch")
|
||||||
pgCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns,
|
pgCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns,
|
||||||
"ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
|
"ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
|
||||||
|
pgCmd.PersistentFlags().StringVarPWithDefaultValue(&command.VarStringCachePrefix, "prefix", "p", "cache")
|
||||||
|
|
||||||
mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t")
|
mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t")
|
||||||
mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c")
|
mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c")
|
||||||
|
|||||||
@@ -225,7 +225,20 @@ func PostgreSqlDataSource(_ *cobra.Command, _ []string) error {
|
|||||||
}
|
}
|
||||||
ignoreColumns := mergeColumns(VarStringSliceIgnoreColumns)
|
ignoreColumns := mergeColumns(VarStringSliceIgnoreColumns)
|
||||||
|
|
||||||
return fromPostgreSqlDataSource(url, patterns, dir, schema, cfg, cache, idea, VarBoolStrict, ignoreColumns)
|
arg := pgDataSourceArg{
|
||||||
|
url: url,
|
||||||
|
dir: dir,
|
||||||
|
tablePat: patterns,
|
||||||
|
schema: schema,
|
||||||
|
cfg: cfg,
|
||||||
|
cache: cache,
|
||||||
|
idea: idea,
|
||||||
|
strict: VarBoolStrict,
|
||||||
|
ignoreColumns: ignoreColumns,
|
||||||
|
prefix: VarStringCachePrefix,
|
||||||
|
}
|
||||||
|
|
||||||
|
return fromPostgreSqlDataSource(arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ddlArg struct {
|
type ddlArg struct {
|
||||||
@@ -339,32 +352,43 @@ func fromMysqlDataSource(arg dataSourceArg) error {
|
|||||||
return generator.StartFromInformationSchema(matchTables, arg.cache, arg.strict)
|
return generator.StartFromInformationSchema(matchTables, arg.cache, arg.strict)
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromPostgreSqlDataSource(url string, pattern pattern, dir, schema string, cfg *config.Config, cache, idea, strict bool, ignoreColumns []string) error {
|
type pgDataSourceArg struct {
|
||||||
log := console.NewConsole(idea)
|
url, dir string
|
||||||
if len(url) == 0 {
|
tablePat pattern
|
||||||
|
schema string
|
||||||
|
cfg *config.Config
|
||||||
|
cache, idea bool
|
||||||
|
strict bool
|
||||||
|
ignoreColumns []string
|
||||||
|
prefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromPostgreSqlDataSource(arg pgDataSourceArg) error {
|
||||||
|
log := console.NewConsole(arg.idea)
|
||||||
|
if len(arg.url) == 0 {
|
||||||
log.Error("%v", "expected data source of postgresql, but nothing found")
|
log.Error("%v", "expected data source of postgresql, but nothing found")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(pattern) == 0 {
|
if len(arg.tablePat) == 0 {
|
||||||
log.Error("%v", "expected table or table globbing patterns, but nothing found")
|
log.Error("%v", "expected table or table globbing patterns, but nothing found")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
db := postgres.New(url)
|
db := postgres.New(arg.url)
|
||||||
im := model.NewPostgreSqlModel(db)
|
im := model.NewPostgreSqlModel(db)
|
||||||
|
|
||||||
tables, err := im.GetAllTables(schema)
|
tables, err := im.GetAllTables(arg.schema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
matchTables := make(map[string]*model.Table)
|
matchTables := make(map[string]*model.Table)
|
||||||
for _, item := range tables {
|
for _, item := range tables {
|
||||||
if !pattern.Match(item) {
|
if !arg.tablePat.Match(item) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
columnData, err := im.FindColumns(schema, item)
|
columnData, err := im.FindColumns(arg.schema, item)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -381,10 +405,11 @@ func fromPostgreSqlDataSource(url string, pattern pattern, dir, schema string, c
|
|||||||
return errors.New("no tables matched")
|
return errors.New("no tables matched")
|
||||||
}
|
}
|
||||||
|
|
||||||
generator, err := gen.NewDefaultGenerator("", dir, cfg, gen.WithConsoleOption(log), gen.WithPostgreSql(), gen.WithIgnoreColumns(ignoreColumns))
|
generator, err := gen.NewDefaultGenerator(arg.prefix, arg.dir, arg.cfg, gen.WithConsoleOption(log),
|
||||||
|
gen.WithPostgreSql(), gen.WithIgnoreColumns(arg.ignoreColumns))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return generator.StartFromInformationSchema(matchTables, cache, strict)
|
return generator.StartFromInformationSchema(matchTables, arg.cache, arg.strict)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -465,6 +465,20 @@ func (p *Parser) parsePathItem() []token.Token {
|
|||||||
if !p.advanceIfPeekTokenIs(token.IDENT) {
|
if !p.advanceIfPeekTokenIs(token.IDENT) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
list = append(list, p.curTok)
|
||||||
|
} else if p.peekTokenIs(token.DOT) {
|
||||||
|
// Allow dot (.) in path segments for file extensions like .php, .html, etc.
|
||||||
|
if !p.nextToken() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
list = append(list, p.curTok)
|
||||||
|
|
||||||
|
// After a dot, we expect an identifier (e.g., .php, .html)
|
||||||
|
if !p.advanceIfPeekTokenIs(token.IDENT) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
list = append(list, p.curTok)
|
list = append(list, p.curTok)
|
||||||
} else {
|
} else {
|
||||||
if p.peekTokenIs(token.LPAREN, token.Returns, token.AT_DOC, token.AT_HANDLER, token.SEMICOLON, token.RBRACE) {
|
if p.peekTokenIs(token.LPAREN, token.Returns, token.AT_DOC, token.AT_HANDLER, token.SEMICOLON, token.RBRACE) {
|
||||||
|
|||||||
@@ -760,11 +760,6 @@ func TestParser_Parse_service(t *testing.T) {
|
|||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
Request: &ast.BodyStmt{
|
Request: &ast.BodyStmt{
|
||||||
LParen: ast.NewTokenNode(token.Token{Type: token.LPAREN, Text: "("}),
|
|
||||||
RParen: ast.NewTokenNode(token.Token{Type: token.RPAREN, Text: ")"}),
|
|
||||||
},
|
|
||||||
Returns: ast.NewTokenNode(token.Token{Type: token.IDENT, Text: "returns"}),
|
|
||||||
Response: &ast.BodyStmt{
|
|
||||||
LParen: ast.NewTokenNode(token.Token{Type: token.LPAREN, Text: "("}),
|
LParen: ast.NewTokenNode(token.Token{Type: token.LPAREN, Text: "("}),
|
||||||
Body: &ast.BodyExpr{
|
Body: &ast.BodyExpr{
|
||||||
Value: ast.NewTokenNode(token.Token{Type: token.IDENT, Text: "Foo"}),
|
Value: ast.NewTokenNode(token.Token{Type: token.IDENT, Text: "Foo"}),
|
||||||
@@ -1031,6 +1026,7 @@ func TestParser_Parse_pathItem(t *testing.T) {
|
|||||||
{input: "1", expected: "1"},
|
{input: "1", expected: "1"},
|
||||||
{input: "11", expected: "11"},
|
{input: "11", expected: "11"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range testData {
|
for _, v := range testData {
|
||||||
p := New("foo.api", v.input)
|
p := New("foo.api", v.input)
|
||||||
ok := p.nextToken()
|
ok := p.nextToken()
|
||||||
@@ -1066,6 +1062,38 @@ func TestParser_Parse_pathItem(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParser_Parse_pathItem_WithDot(t *testing.T) {
|
||||||
|
t.Run("valid with dots", func(t *testing.T) {
|
||||||
|
var testData = []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{input: "file.php", expected: "file.php"},
|
||||||
|
{input: "api_jsonrpc.php", expected: "api_jsonrpc.php"},
|
||||||
|
{input: "index.html", expected: "index.html"},
|
||||||
|
{input: "data.json", expected: "data.json"},
|
||||||
|
{input: "style.css", expected: "style.css"},
|
||||||
|
{input: "script.js", expected: "script.js"},
|
||||||
|
{input: "document.pdf", expected: "document.pdf"},
|
||||||
|
{input: "image.png", expected: "image.png"},
|
||||||
|
{input: "api.v1", expected: "api.v1"},
|
||||||
|
{input: "resource.with.multiple.dots", expected: "resource.with.multiple.dots"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range testData {
|
||||||
|
p := New("foo.api", v.input)
|
||||||
|
ok := p.nextToken()
|
||||||
|
assert.True(t, ok)
|
||||||
|
tokens := p.parsePathItem()
|
||||||
|
var expected []string
|
||||||
|
for _, tok := range tokens {
|
||||||
|
expected = append(expected, tok.Text)
|
||||||
|
}
|
||||||
|
assert.Equal(t, strings.Join(expected, ""), v.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestParser_Parse_parseTypeStmt(t *testing.T) {
|
func TestParser_Parse_parseTypeStmt(t *testing.T) {
|
||||||
assertEqual := func(t *testing.T, expected, actual ast.Stmt) {
|
assertEqual := func(t *testing.T, expected, actual ast.Stmt) {
|
||||||
if expected == nil {
|
if expected == nil {
|
||||||
@@ -1399,6 +1427,7 @@ func TestParser_Parse_parseTypeStmt(t *testing.T) {
|
|||||||
assertEqual(t, val.expected, one)
|
assertEqual(t, val.expected, one)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("parseTypeGroupStmt", func(t *testing.T) {
|
t.Run("parseTypeGroupStmt", func(t *testing.T) {
|
||||||
var testData = []struct {
|
var testData = []struct {
|
||||||
input string
|
input string
|
||||||
@@ -1472,6 +1501,7 @@ func TestParser_Parse_parseTypeStmt(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, val := range testData {
|
for _, val := range testData {
|
||||||
p := New("test.api", val.input)
|
p := New("test.api", val.input)
|
||||||
result := p.Parse()
|
result := p.Parse()
|
||||||
|
|||||||
Reference in New Issue
Block a user