mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-06-27 06:21:00 +08:00
Compare commits
99 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c9ff6a10d3 | ||
|
|
a71e56de52 | ||
|
|
bae8d4f4c8 | ||
|
|
8c6266f338 | ||
|
|
95d5b81f44 | ||
|
|
bca7bbc142 | ||
|
|
df9a52664b | ||
|
|
937cf0db96 | ||
|
|
75cebb65f8 | ||
|
|
410f56e73a | ||
|
|
017909a3ab | ||
|
|
0d31e6c375 | ||
|
|
0ba86b1849 | ||
|
|
4cacc4d9d3 | ||
|
|
a99c14da4a | ||
|
|
985582264a | ||
|
|
8364e341e1 | ||
|
|
0f2b589d4d | ||
|
|
19fec36d24 | ||
|
|
f037bf344d | ||
|
|
d99cf35b07 | ||
|
|
f459f1b5ff | ||
|
|
0140fd417b | ||
|
|
7969e0ca38 | ||
|
|
91c885b5b0 | ||
|
|
d4cccca387 | ||
|
|
4b2095ed03 | ||
|
|
1229eeb2d2 | ||
|
|
9142b146c5 | ||
|
|
8a1b2d5aed | ||
|
|
da5d39e6ca | ||
|
|
68c5a17c67 | ||
|
|
b53f9f5f2d | ||
|
|
36d57626b6 | ||
|
|
4e36ba832f | ||
|
|
a44954a771 | ||
|
|
f3edd4b880 | ||
|
|
2de3e397ff | ||
|
|
a435eb56f2 | ||
|
|
d80761c147 | ||
|
|
e7bd0d8b60 | ||
|
|
b109b3ef4c | ||
|
|
e3c371ac89 | ||
|
|
15eb6f4f6d | ||
|
|
4d3681b71c | ||
|
|
a682bda0bb | ||
|
|
45b27ad93a | ||
|
|
292a8302a1 | ||
|
|
91ab1f6d2b | ||
|
|
5048c350ae | ||
|
|
94edc32f3e | ||
|
|
ec989b2e2a | ||
|
|
82fe802e81 | ||
|
|
072d68f897 | ||
|
|
2e91ba5811 | ||
|
|
5564c43197 | ||
|
|
e55158b0f7 | ||
|
|
69aa7fe346 | ||
|
|
c3820a95c1 | ||
|
|
493f3bad0f | ||
|
|
eb0d5ad3a4 | ||
|
|
14192050ae | ||
|
|
9193e771e3 | ||
|
|
808b4e496a | ||
|
|
e416d01f8d | ||
|
|
789c5de873 | ||
|
|
52078a0c14 | ||
|
|
7ef13116a0 | ||
|
|
6b8053410a | ||
|
|
81c6928445 | ||
|
|
761c2dd716 | ||
|
|
aeceb3cfbe | ||
|
|
15ea07aad1 | ||
|
|
98bebbc74f | ||
|
|
eafd11d949 | ||
|
|
b251ce346e | ||
|
|
812140ba36 | ||
|
|
44735e949c | ||
|
|
bf313c3c56 | ||
|
|
94e7753262 | ||
|
|
9c478626d2 | ||
|
|
801c283478 | ||
|
|
2a54faf997 | ||
|
|
ecd98f3653 | ||
|
|
61641581eb | ||
|
|
6f2730d5ae | ||
|
|
0eff777b62 | ||
|
|
cafbf535f7 | ||
|
|
6edfce63e3 | ||
|
|
cdb0098b18 | ||
|
|
620c7f9693 | ||
|
|
dba444a382 | ||
|
|
b24fb3ebf7 | ||
|
|
967f0926eb | ||
|
|
e68c683df9 | ||
|
|
247985a065 | ||
|
|
80573af0d8 | ||
|
|
c0394b631a | ||
|
|
68d1aba377 |
18
.github/workflows/issue-translator.yml
vendored
18
.github/workflows/issue-translator.yml
vendored
@@ -1,18 +0,0 @@
|
|||||||
name: 'issue-translator'
|
|
||||||
on:
|
|
||||||
issue_comment:
|
|
||||||
types: [created]
|
|
||||||
issues:
|
|
||||||
types: [opened]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: usthe/issues-translate-action@v2.7
|
|
||||||
with:
|
|
||||||
IS_MODIFY_TITLE: true
|
|
||||||
# not require, default false, . Decide whether to modify the issue title
|
|
||||||
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
|
|
||||||
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑🤝🧑👫🧑🏿🤝🧑🏻👩🏾🤝👨🏿👬🏿
|
|
||||||
# not require. Customize the translation robot prefix message.
|
|
||||||
7
.github/workflows/version-check.yml
vendored
7
.github/workflows/version-check.yml
vendored
@@ -21,7 +21,8 @@ jobs:
|
|||||||
id: get_version
|
id: get_version
|
||||||
run: |
|
run: |
|
||||||
# Extract version from tools/goctl/v* format
|
# Extract version from tools/goctl/v* format
|
||||||
echo "VERSION=${GITHUB_REF#refs/tags/tools/goctl/v}" >> $GITHUB_ENV
|
VERSION="${GITHUB_REF#refs/tags/tools/goctl/v}"
|
||||||
|
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||||
echo "Extracted version: $VERSION"
|
echo "Extracted version: $VERSION"
|
||||||
|
|
||||||
- name: Check version in goctl source code
|
- name: Check version in goctl source code
|
||||||
@@ -31,7 +32,11 @@ jobs:
|
|||||||
|
|
||||||
# Check version in BuildVersion constant
|
# Check version in BuildVersion constant
|
||||||
VERSION_IN_CODE=$(grep -r "const BuildVersion =" . | grep -o '".*"' | tr -d '"')
|
VERSION_IN_CODE=$(grep -r "const BuildVersion =" . | grep -o '".*"' | tr -d '"')
|
||||||
|
echo "Version in code: $VERSION_IN_CODE"
|
||||||
|
echo "Expected version: $VERSION"
|
||||||
|
|
||||||
if [ "$VERSION_IN_CODE" != "$VERSION" ]; then
|
if [ "$VERSION_IN_CODE" != "$VERSION" ]; then
|
||||||
echo "Version mismatch: Version in code ($VERSION_IN_CODE) doesn't match tag version ($VERSION)"
|
echo "Version mismatch: Version in code ($VERSION_IN_CODE) doesn't match tag version ($VERSION)"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
echo "✅ Version check passed!"
|
||||||
|
|||||||
@@ -8,16 +8,12 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/mathx"
|
|
||||||
"github.com/zeromicro/go-zero/core/proc"
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
"github.com/zeromicro/go-zero/core/stringx"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const numHistoryReasons = 5
|
||||||
numHistoryReasons = 5
|
|
||||||
timeFormat = "15:04:05"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrServiceUnavailable is returned when the Breaker state is open.
|
// ErrServiceUnavailable is returned when the Breaker state is open.
|
||||||
var ErrServiceUnavailable = errors.New("circuit breaker is open")
|
var ErrServiceUnavailable = errors.New("circuit breaker is open")
|
||||||
@@ -262,9 +258,9 @@ type errorWindow struct {
|
|||||||
|
|
||||||
func (ew *errorWindow) add(reason string) {
|
func (ew *errorWindow) add(reason string) {
|
||||||
ew.lock.Lock()
|
ew.lock.Lock()
|
||||||
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason)
|
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(time.TimeOnly), reason)
|
||||||
ew.index = (ew.index + 1) % numHistoryReasons
|
ew.index = (ew.index + 1) % numHistoryReasons
|
||||||
ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
|
ew.count = min(ew.count+1, numHistoryReasons)
|
||||||
ew.lock.Unlock()
|
ew.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -423,7 +423,7 @@ func TestRegistry_Monitor(t *testing.T) {
|
|||||||
GetRegistry().clusters = map[string]*cluster{
|
GetRegistry().clusters = map[string]*cluster{
|
||||||
getClusterKey(endpoints): {
|
getClusterKey(endpoints): {
|
||||||
watchers: map[watchKey]*watchValue{
|
watchers: map[watchKey]*watchValue{
|
||||||
watchKey{
|
{
|
||||||
key: "foo",
|
key: "foo",
|
||||||
exactMatch: true,
|
exactMatch: true,
|
||||||
}: {
|
}: {
|
||||||
@@ -449,7 +449,7 @@ func TestRegistry_Unmonitor(t *testing.T) {
|
|||||||
GetRegistry().clusters = map[string]*cluster{
|
GetRegistry().clusters = map[string]*cluster{
|
||||||
getClusterKey(endpoints): {
|
getClusterKey(endpoints): {
|
||||||
watchers: map[watchKey]*watchValue{
|
watchers: map[watchKey]*watchValue{
|
||||||
watchKey{
|
{
|
||||||
key: "foo",
|
key: "foo",
|
||||||
exactMatch: true,
|
exactMatch: true,
|
||||||
}: {
|
}: {
|
||||||
|
|||||||
@@ -86,21 +86,16 @@ func TestConsistentHashIncrementalTransfer(t *testing.T) {
|
|||||||
|
|
||||||
func TestConsistentHashTransferOnFailure(t *testing.T) {
|
func TestConsistentHashTransferOnFailure(t *testing.T) {
|
||||||
index := 41
|
index := 41
|
||||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
ratioNotExists := getTransferRatioOnFailure(t, index)
|
||||||
var transferred int
|
assert.True(t, ratioNotExists == 0, fmt.Sprintf("%d: %f", index, ratioNotExists))
|
||||||
for k, v := range newKeys {
|
index = 13
|
||||||
if v != keys[k] {
|
ratio := getTransferRatioOnFailure(t, index)
|
||||||
transferred++
|
assert.True(t, ratio < 2.5/keySize, fmt.Sprintf("%d: %f", index, ratio))
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ratio := float32(transferred) / float32(requestSize)
|
|
||||||
assert.True(t, ratio < 2.5/float32(keySize), fmt.Sprintf("%d: %f", index, ratio))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConsistentHashLeastTransferOnFailure(t *testing.T) {
|
func TestConsistentHashLeastTransferOnFailure(t *testing.T) {
|
||||||
prefix := "localhost:"
|
prefix := "localhost:"
|
||||||
index := 41
|
index := 13
|
||||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index)
|
keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index)
|
||||||
for k, v := range keys {
|
for k, v := range keys {
|
||||||
newV := newKeys[k]
|
newV := newKeys[k]
|
||||||
@@ -164,6 +159,17 @@ func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[i
|
|||||||
return keys, newKeys
|
return keys, newKeys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTransferRatioOnFailure(t *testing.T, index int) float32 {
|
||||||
|
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
||||||
|
var transferred int
|
||||||
|
for k, v := range newKeys {
|
||||||
|
if v != keys[k] {
|
||||||
|
transferred++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return float32(transferred) / float32(requestSize)
|
||||||
|
}
|
||||||
|
|
||||||
type mockNode struct {
|
type mockNode struct {
|
||||||
addr string
|
addr string
|
||||||
id int
|
id int
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package hash
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"fmt"
|
"encoding/hex"
|
||||||
|
|
||||||
"github.com/spaolacci/murmur3"
|
"github.com/spaolacci/murmur3"
|
||||||
)
|
)
|
||||||
@@ -20,6 +20,7 @@ func Md5(data []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Md5Hex returns the md5 hex string of data.
|
// Md5Hex returns the md5 hex string of data.
|
||||||
|
// This function is optimized for better performance than fmt.Sprintf.
|
||||||
func Md5Hex(data []byte) string {
|
func Md5Hex(data []byte) string {
|
||||||
return fmt.Sprintf("%x", Md5(data))
|
return hex.EncodeToString(Md5(data))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,12 +7,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
fieldsContextKey contextKey
|
|
||||||
globalFields atomic.Value
|
globalFields atomic.Value
|
||||||
globalFieldsLock sync.Mutex
|
globalFieldsLock sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
type contextKey struct{}
|
type fieldsKey struct{}
|
||||||
|
|
||||||
// AddGlobalFields adds global fields.
|
// AddGlobalFields adds global fields.
|
||||||
func AddGlobalFields(fields ...LogField) {
|
func AddGlobalFields(fields ...LogField) {
|
||||||
@@ -29,16 +28,16 @@ func AddGlobalFields(fields ...LogField) {
|
|||||||
|
|
||||||
// ContextWithFields returns a new context with the given fields.
|
// ContextWithFields returns a new context with the given fields.
|
||||||
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
|
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
|
||||||
if val := ctx.Value(fieldsContextKey); val != nil {
|
if val := ctx.Value(fieldsKey{}); val != nil {
|
||||||
if arr, ok := val.([]LogField); ok {
|
if arr, ok := val.([]LogField); ok {
|
||||||
allFields := make([]LogField, 0, len(arr)+len(fields))
|
allFields := make([]LogField, 0, len(arr)+len(fields))
|
||||||
allFields = append(allFields, arr...)
|
allFields = append(allFields, arr...)
|
||||||
allFields = append(allFields, fields...)
|
allFields = append(allFields, fields...)
|
||||||
return context.WithValue(ctx, fieldsContextKey, allFields)
|
return context.WithValue(ctx, fieldsKey{}, allFields)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return context.WithValue(ctx, fieldsContextKey, fields)
|
return context.WithValue(ctx, fieldsKey{}, fields)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithFields returns a new logger with the given fields.
|
// WithFields returns a new logger with the given fields.
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func TestAddGlobalFields(t *testing.T) {
|
|||||||
|
|
||||||
func TestContextWithFields(t *testing.T) {
|
func TestContextWithFields(t *testing.T) {
|
||||||
ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2))
|
ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2))
|
||||||
vals := ctx.Value(fieldsContextKey)
|
vals := ctx.Value(fieldsKey{})
|
||||||
assert.NotNil(t, vals)
|
assert.NotNil(t, vals)
|
||||||
fields, ok := vals.([]LogField)
|
fields, ok := vals.([]LogField)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
@@ -43,7 +43,7 @@ func TestContextWithFields(t *testing.T) {
|
|||||||
|
|
||||||
func TestWithFields(t *testing.T) {
|
func TestWithFields(t *testing.T) {
|
||||||
ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2))
|
ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2))
|
||||||
vals := ctx.Value(fieldsContextKey)
|
vals := ctx.Value(fieldsKey{})
|
||||||
assert.NotNil(t, vals)
|
assert.NotNil(t, vals)
|
||||||
fields, ok := vals.([]LogField)
|
fields, ok := vals.([]LogField)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
@@ -55,7 +55,7 @@ func TestWithFieldsAppend(t *testing.T) {
|
|||||||
ctx := context.WithValue(context.Background(), dummyKey, "dummy")
|
ctx := context.WithValue(context.Background(), dummyKey, "dummy")
|
||||||
ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2))
|
ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2))
|
||||||
ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4))
|
ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4))
|
||||||
vals := ctx.Value(fieldsContextKey)
|
vals := ctx.Value(fieldsKey{})
|
||||||
assert.NotNil(t, vals)
|
assert.NotNil(t, vals)
|
||||||
fields, ok := vals.([]LogField)
|
fields, ok := vals.([]LogField)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
@@ -80,8 +80,8 @@ func TestWithFieldsAppendCopy(t *testing.T) {
|
|||||||
ctxa := ContextWithFields(ctx, af)
|
ctxa := ContextWithFields(ctx, af)
|
||||||
ctxb := ContextWithFields(ctx, bf)
|
ctxb := ContextWithFields(ctx, bf)
|
||||||
|
|
||||||
assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count])
|
assert.EqualValues(t, af, ctxa.Value(fieldsKey{}).([]LogField)[count])
|
||||||
assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count])
|
assert.EqualValues(t, bf, ctxb.Value(fieldsKey{}).([]LogField)[count])
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkAtomicValue(b *testing.B) {
|
func BenchmarkAtomicValue(b *testing.B) {
|
||||||
|
|||||||
@@ -224,7 +224,7 @@ func (l *richLogger) buildFields(fields ...LogField) []LogField {
|
|||||||
fields = append(fields, Field(spanKey, spanID))
|
fields = append(fields, Field(spanKey, spanID))
|
||||||
}
|
}
|
||||||
|
|
||||||
val := l.ctx.Value(fieldsContextKey)
|
val := l.ctx.Value(fieldsKey{})
|
||||||
if val != nil {
|
if val != nil {
|
||||||
if arr, ok := val.([]LogField); ok {
|
if arr, ok := val.([]LogField); ok {
|
||||||
fields = append(fields, arr...)
|
fields = append(fields, arr...)
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dateFormat = "2006-01-02"
|
|
||||||
hoursPerDay = 24
|
hoursPerDay = 24
|
||||||
bufferSize = 100
|
bufferSize = 100
|
||||||
defaultDirMode = 0o755
|
defaultDirMode = 0o755
|
||||||
@@ -116,7 +115,7 @@ func (r *DailyRotateRule) OutdatedFiles() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var buf strings.Builder
|
var buf strings.Builder
|
||||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(dateFormat)
|
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(time.DateOnly)
|
||||||
buf.WriteString(r.filename)
|
buf.WriteString(r.filename)
|
||||||
buf.WriteString(r.delimiter)
|
buf.WriteString(r.delimiter)
|
||||||
buf.WriteString(boundary)
|
buf.WriteString(boundary)
|
||||||
@@ -425,7 +424,7 @@ func compressLogFile(file string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getNowDate() string {
|
func getNowDate() string {
|
||||||
return time.Now().Format(dateFormat)
|
return time.Now().Format(time.DateOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNowDateInRFC3339Format() string {
|
func getNowDateInRFC3339Format() string {
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("temp files", func(t *testing.T) {
|
t.Run("temp files", func(t *testing.T) {
|
||||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
_ = f1.Close()
|
_ = f1.Close()
|
||||||
@@ -73,7 +73,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
|||||||
|
|
||||||
func TestDailyRotateRuleShallRotate(t *testing.T) {
|
func TestDailyRotateRuleShallRotate(t *testing.T) {
|
||||||
var rule DailyRotateRule
|
var rule DailyRotateRule
|
||||||
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(dateFormat)
|
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(time.DateOnly)
|
||||||
assert.True(t, rule.ShallRotate(0))
|
assert.True(t, rule.ShallRotate(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -117,12 +117,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("temp files", func(t *testing.T) {
|
t.Run("temp files", func(t *testing.T) {
|
||||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -144,12 +144,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("no backups", func(t *testing.T) {
|
t.Run("no backups", func(t *testing.T) {
|
||||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -319,7 +319,7 @@ func TestRotateLoggerWrite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||||
logger.write([]byte(`foo`))
|
logger.write([]byte(`foo`))
|
||||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||||
logger.write([]byte(`bar`))
|
logger.write([]byte(`bar`))
|
||||||
logger.Close()
|
logger.Close()
|
||||||
logger.write([]byte(`baz`))
|
logger.write([]byte(`baz`))
|
||||||
@@ -447,7 +447,7 @@ func TestRotateLoggerWithSizeLimitRotateRuleWrite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||||
logger.write([]byte(`foo`))
|
logger.write([]byte(`foo`))
|
||||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||||
logger.write([]byte(`bar`))
|
logger.write([]byte(`bar`))
|
||||||
logger.Close()
|
logger.Close()
|
||||||
logger.write([]byte(`baz`))
|
logger.write([]byte(`baz`))
|
||||||
|
|||||||
@@ -13,6 +13,15 @@ const (
|
|||||||
|
|
||||||
// Marshal marshals the given val and returns the map that contains the fields.
|
// Marshal marshals the given val and returns the map that contains the fields.
|
||||||
// optional=another is not implemented, and it's hard to implement and not commonly used.
|
// optional=another is not implemented, and it's hard to implement and not commonly used.
|
||||||
|
// support anonymous field, e.g.:
|
||||||
|
//
|
||||||
|
// type Foo struct {
|
||||||
|
// Token string `header:"token"`
|
||||||
|
// }
|
||||||
|
// type FooB struct {
|
||||||
|
// Foo
|
||||||
|
// Bar string `json:"bar"`
|
||||||
|
// }
|
||||||
func Marshal(val any) (map[string]map[string]any, error) {
|
func Marshal(val any) (map[string]map[string]any, error) {
|
||||||
ret := make(map[string]map[string]any)
|
ret := make(map[string]map[string]any)
|
||||||
tp := reflect.TypeOf(val)
|
tp := reflect.TypeOf(val)
|
||||||
@@ -44,6 +53,16 @@ func getTag(field reflect.StructField) (string, bool) {
|
|||||||
return strings.TrimSpace(tag), false
|
return strings.TrimSpace(tag), false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func insertValue(collector map[string]map[string]any, tag string, key string, val any) {
|
||||||
|
if m, ok := collector[tag]; ok {
|
||||||
|
m[key] = val
|
||||||
|
} else {
|
||||||
|
collector[tag] = map[string]any{
|
||||||
|
key: val,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func processMember(field reflect.StructField, value reflect.Value,
|
func processMember(field reflect.StructField, value reflect.Value,
|
||||||
collector map[string]map[string]any) error {
|
collector map[string]map[string]any) error {
|
||||||
var key string
|
var key string
|
||||||
@@ -69,15 +88,20 @@ func processMember(field reflect.StructField, value reflect.Value,
|
|||||||
val = fmt.Sprint(val)
|
val = fmt.Sprint(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
m, ok := collector[tag]
|
if field.Anonymous {
|
||||||
if ok {
|
anonCollector, err := Marshal(val)
|
||||||
m[key] = val
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for anonTag, anonMap := range anonCollector {
|
||||||
|
for anonKey, anonVal := range anonMap {
|
||||||
|
insertValue(collector, anonTag, anonKey, anonVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
m = map[string]any{
|
insertValue(collector, tag, key, val)
|
||||||
key: val,
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
collector[tag] = m
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -118,7 +142,7 @@ func validateOptional(field reflect.StructField, value reflect.Value) error {
|
|||||||
if value.IsNil() {
|
if value.IsNil() {
|
||||||
return fmt.Errorf("field %q is nil", field.Name)
|
return fmt.Errorf("field %q is nil", field.Name)
|
||||||
}
|
}
|
||||||
case reflect.Array, reflect.Slice, reflect.Map:
|
case reflect.Slice, reflect.Map:
|
||||||
if value.IsNil() || value.Len() == 0 {
|
if value.IsNil() || value.Len() == 0 {
|
||||||
return fmt.Errorf("field %q is empty", field.Name)
|
return fmt.Errorf("field %q is empty", field.Name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,6 +27,124 @@ func TestMarshal(t *testing.T) {
|
|||||||
assert.True(t, m[emptyTag]["Anonymous"].(bool))
|
assert.True(t, m[emptyTag]["Anonymous"].(bool))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMarshal_Anonymous(t *testing.T) {
|
||||||
|
t.Run("anonymous", func(t *testing.T) {
|
||||||
|
type BaseHeader struct {
|
||||||
|
Token string `header:"token"`
|
||||||
|
}
|
||||||
|
v := struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Address string `json:"address,options=[beijing,shanghai]"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
BaseHeader
|
||||||
|
}{
|
||||||
|
Name: "kevin",
|
||||||
|
Address: "shanghai",
|
||||||
|
Age: 20,
|
||||||
|
BaseHeader: BaseHeader{
|
||||||
|
Token: "token_xxx",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m, err := Marshal(v)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "kevin", m["json"]["name"])
|
||||||
|
assert.Equal(t, "shanghai", m["json"]["address"])
|
||||||
|
assert.Equal(t, 20, m["json"]["age"].(int))
|
||||||
|
assert.Equal(t, "token_xxx", m["header"]["token"])
|
||||||
|
|
||||||
|
v1 := struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Address string `json:"address,options=[beijing,shanghai]"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
BaseHeader
|
||||||
|
}{
|
||||||
|
Name: "kevin",
|
||||||
|
Address: "shanghai",
|
||||||
|
Age: 20,
|
||||||
|
}
|
||||||
|
m1, err1 := Marshal(v1)
|
||||||
|
assert.Nil(t, err1)
|
||||||
|
assert.Equal(t, "kevin", m1["json"]["name"])
|
||||||
|
assert.Equal(t, "shanghai", m1["json"]["address"])
|
||||||
|
assert.Equal(t, 20, m1["json"]["age"].(int))
|
||||||
|
|
||||||
|
type AnotherHeader struct {
|
||||||
|
Version string `header:"version"`
|
||||||
|
}
|
||||||
|
v2 := struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Address string `json:"address,options=[beijing,shanghai]"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
BaseHeader
|
||||||
|
AnotherHeader
|
||||||
|
}{
|
||||||
|
Name: "kevin",
|
||||||
|
Address: "shanghai",
|
||||||
|
Age: 20,
|
||||||
|
BaseHeader: BaseHeader{
|
||||||
|
Token: "token_xxx",
|
||||||
|
},
|
||||||
|
AnotherHeader: AnotherHeader{
|
||||||
|
Version: "v1.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m2, err2 := Marshal(v2)
|
||||||
|
assert.Nil(t, err2)
|
||||||
|
assert.Equal(t, "kevin", m2["json"]["name"])
|
||||||
|
assert.Equal(t, "shanghai", m2["json"]["address"])
|
||||||
|
assert.Equal(t, 20, m2["json"]["age"].(int))
|
||||||
|
assert.Equal(t, "token_xxx", m2["header"]["token"])
|
||||||
|
assert.Equal(t, "v1.0", m2["header"]["version"])
|
||||||
|
|
||||||
|
type PointerHeader struct {
|
||||||
|
Ref *string `header:"ref"`
|
||||||
|
}
|
||||||
|
ref := "reference"
|
||||||
|
v3 := struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Address string `json:"address,options=[beijing,shanghai]"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
PointerHeader
|
||||||
|
}{
|
||||||
|
Name: "kevin",
|
||||||
|
Address: "shanghai",
|
||||||
|
Age: 20,
|
||||||
|
PointerHeader: PointerHeader{
|
||||||
|
Ref: &ref,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m3, err3 := Marshal(v3)
|
||||||
|
assert.Nil(t, err3)
|
||||||
|
assert.Equal(t, "kevin", m3["json"]["name"])
|
||||||
|
assert.Equal(t, "shanghai", m3["json"]["address"])
|
||||||
|
assert.Equal(t, 20, m3["json"]["age"].(int))
|
||||||
|
assert.Equal(t, "reference", *m3["header"]["ref"].(*string))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bad anonymous", func(t *testing.T) {
|
||||||
|
type BaseHeader struct {
|
||||||
|
Token string `json:"token,options=[a,b]"`
|
||||||
|
}
|
||||||
|
|
||||||
|
v := struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Address string `json:"address,options=[beijing,shanghai]"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
BaseHeader
|
||||||
|
}{
|
||||||
|
Name: "kevin",
|
||||||
|
Address: "shanghai",
|
||||||
|
Age: 20,
|
||||||
|
BaseHeader: BaseHeader{
|
||||||
|
Token: "c",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Marshal(v)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestMarshal_Ptr(t *testing.T) {
|
func TestMarshal_Ptr(t *testing.T) {
|
||||||
v := &struct {
|
v := &struct {
|
||||||
Name string `path:"name"`
|
Name string `path:"name"`
|
||||||
@@ -344,3 +462,15 @@ func TestMarshal_FromString(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10", m["json"]["age"].(string))
|
assert.Equal(t, "10", m["json"]["age"].(string))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMarshal_Array(t *testing.T) {
|
||||||
|
v := struct {
|
||||||
|
H [1]int `json:"h,string"`
|
||||||
|
}{
|
||||||
|
H: [1]int{1},
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Marshal(v)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "[1]", m["json"]["h"].(string))
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -15,7 +16,6 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/jsonx"
|
"github.com/zeromicro/go-zero/core/jsonx"
|
||||||
"github.com/zeromicro/go-zero/core/lang"
|
"github.com/zeromicro/go-zero/core/lang"
|
||||||
"github.com/zeromicro/go-zero/core/proc"
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -30,7 +30,9 @@ var (
|
|||||||
errValueNotSettable = errors.New("value is not settable")
|
errValueNotSettable = errors.New("value is not settable")
|
||||||
errValueNotStruct = errors.New("value type is not struct")
|
errValueNotStruct = errors.New("value type is not struct")
|
||||||
keyUnmarshaler = NewUnmarshaler(defaultKeyName)
|
keyUnmarshaler = NewUnmarshaler(defaultKeyName)
|
||||||
|
boolType = reflect.TypeOf(false)
|
||||||
durationType = reflect.TypeOf(time.Duration(0))
|
durationType = reflect.TypeOf(time.Duration(0))
|
||||||
|
stringType = reflect.TypeOf("")
|
||||||
cacheKeys = make(map[string][]string)
|
cacheKeys = make(map[string][]string)
|
||||||
cacheKeysLock sync.Mutex
|
cacheKeysLock sync.Mutex
|
||||||
defaultCache = make(map[string]any)
|
defaultCache = make(map[string]any)
|
||||||
@@ -622,9 +624,19 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
|
|||||||
|
|
||||||
return u.fillSliceFromString(fieldType, value, mapValue, fullName)
|
return u.fillSliceFromString(fieldType, value, mapValue, fullName)
|
||||||
case valueKind == reflect.String && derefedFieldType == durationType:
|
case valueKind == reflect.String && derefedFieldType == durationType:
|
||||||
return fillDurationValue(fieldType, value, mapValue.(string))
|
v, err := convertToString(mapValue, fullName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fillDurationValue(fieldType, value, v)
|
||||||
case valueKind == reflect.String && typeKind == reflect.Struct && u.implementsUnmarshaler(fieldType):
|
case valueKind == reflect.String && typeKind == reflect.Struct && u.implementsUnmarshaler(fieldType):
|
||||||
return u.fillUnmarshalerStruct(fieldType, value, mapValue.(string))
|
v, err := convertToString(mapValue, fullName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return u.fillUnmarshalerStruct(fieldType, value, v)
|
||||||
default:
|
default:
|
||||||
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
|
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
|
||||||
}
|
}
|
||||||
@@ -755,24 +767,24 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fieldKind := fieldType.Kind()
|
derefType := Deref(fieldType)
|
||||||
switch fieldKind {
|
switch derefType {
|
||||||
case reflect.Bool:
|
case boolType:
|
||||||
val, err := strconv.ParseBool(envVal)
|
val, err := strconv.ParseBool(envVal)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
value.SetBool(val)
|
SetValue(fieldType, value, reflect.ValueOf(val))
|
||||||
return nil
|
return nil
|
||||||
case durationType.Kind():
|
case durationType:
|
||||||
if err := fillDurationValue(fieldType, value, envVal); err != nil {
|
if err := fillDurationValue(fieldType, value, envVal); err != nil {
|
||||||
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
case reflect.String:
|
case stringType:
|
||||||
value.SetString(envVal)
|
SetValue(fieldType, value, reflect.ValueOf(envVal))
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, json.Number(envVal), opts, fullName)
|
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, json.Number(envVal), opts, fullName)
|
||||||
@@ -894,7 +906,7 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ
|
|||||||
valueKind.String())
|
valueKind.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if !stringx.Contains(options, checkValue) {
|
if !slices.Contains(options, checkValue) {
|
||||||
return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`,
|
return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`,
|
||||||
mapValue, key, options)
|
mapValue, key, options)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -203,6 +203,20 @@ func TestUnmarshalDuration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalDurationUnexpectedError(t *testing.T) {
|
||||||
|
type inner struct {
|
||||||
|
Duration time.Duration `key:"duration"`
|
||||||
|
}
|
||||||
|
content := "{\"duration\": 1}"
|
||||||
|
var m = map[string]any{}
|
||||||
|
err := jsonx.Unmarshal([]byte(content), &m)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
var in inner
|
||||||
|
err = UnmarshalKey(m, &in)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "expect string")
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnmarshalDurationDefault(t *testing.T) {
|
func TestUnmarshalDurationDefault(t *testing.T) {
|
||||||
type inner struct {
|
type inner struct {
|
||||||
Int int `key:"int"`
|
Int int `key:"int"`
|
||||||
@@ -4665,6 +4679,23 @@ func TestUnmarshal_EnvInt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnmarshal_EnvInt64(t *testing.T) {
|
||||||
|
type Value struct {
|
||||||
|
Age int64 `key:"age,env=TEST_NAME_INT64"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
envName = "TEST_NAME_INT64"
|
||||||
|
envVal = "88"
|
||||||
|
)
|
||||||
|
t.Setenv(envName, envVal)
|
||||||
|
|
||||||
|
var v Value
|
||||||
|
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||||
|
assert.Equal(t, int64(88), v.Age)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
|
func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
|
||||||
type Value struct {
|
type Value struct {
|
||||||
Age int `key:"age,env=TEST_NAME_INT"`
|
Age int `key:"age,env=TEST_NAME_INT"`
|
||||||
@@ -4770,20 +4801,33 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshal_EnvDuration(t *testing.T) {
|
func TestUnmarshal_EnvDuration(t *testing.T) {
|
||||||
type Value struct {
|
|
||||||
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
envName = "TEST_NAME_DURATION"
|
envName = "TEST_NAME_DURATION"
|
||||||
envVal = "1s"
|
envVal = "1s"
|
||||||
)
|
)
|
||||||
t.Setenv(envName, envVal)
|
t.Setenv(envName, envVal)
|
||||||
|
|
||||||
|
t.Run("valid duration", func(t *testing.T) {
|
||||||
|
type Value struct {
|
||||||
|
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||||
|
}
|
||||||
|
|
||||||
var v Value
|
var v Value
|
||||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||||
assert.Equal(t, time.Second, v.Duration)
|
assert.Equal(t, time.Second, v.Duration)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ptr of duration", func(t *testing.T) {
|
||||||
|
type Value struct {
|
||||||
|
Duration *time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var v Value
|
||||||
|
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||||
|
assert.Equal(t, time.Second, *v.Duration)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
|
func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
|
||||||
@@ -5995,6 +6039,16 @@ func TestUnmarshal_Unmarshaler(t *testing.T) {
|
|||||||
}, &v))
|
}, &v))
|
||||||
assert.Nil(t, v.Foo)
|
assert.Nil(t, v.Foo)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("json.Number", func(t *testing.T) {
|
||||||
|
v := struct {
|
||||||
|
Foo *mockUnmarshaler `json:"name"`
|
||||||
|
}{}
|
||||||
|
m := map[string]any{
|
||||||
|
"name": json.Number("123"),
|
||||||
|
}
|
||||||
|
assert.Error(t, UnmarshalJsonMap(m, &v))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseJsonStringValue(t *testing.T) {
|
func TestParseJsonStringValue(t *testing.T) {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -91,6 +92,15 @@ func ValidatePtr(v reflect.Value) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertToString(val any, fullName string) (string, error) {
|
||||||
|
v, ok := val.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("expect string for field %s, but got type %T", fullName, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
|
func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
|
||||||
switch kind {
|
switch kind {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
@@ -634,11 +644,11 @@ func validateValueInOptions(val any, options []string) error {
|
|||||||
if len(options) > 0 {
|
if len(options) > 0 {
|
||||||
switch v := val.(type) {
|
switch v := val.(type) {
|
||||||
case string:
|
case string:
|
||||||
if !stringx.Contains(options, v) {
|
if !slices.Contains(options, v) {
|
||||||
return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options)
|
return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if !stringx.Contains(options, Repr(v)) {
|
if !slices.Contains(options, Repr(v)) {
|
||||||
return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
|
return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,13 @@
|
|||||||
package mathx
|
package mathx
|
||||||
|
|
||||||
// MaxInt returns the larger one of a and b.
|
// MaxInt returns the larger one of a and b.
|
||||||
|
// Deprecated: use builtin max instead.
|
||||||
func MaxInt(a, b int) int {
|
func MaxInt(a, b int) int {
|
||||||
if a > b {
|
return max(a, b)
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
return b
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MinInt returns the smaller one of a and b.
|
// MinInt returns the smaller one of a and b.
|
||||||
|
// Deprecated: use builtin min instead.
|
||||||
func MinInt(a, b int) int {
|
func MinInt(a, b int) int {
|
||||||
if a < b {
|
return min(a, b)
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
return b
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
package prof
|
package prof
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"fmt"
|
||||||
"strconv"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/olekukonko/tablewriter"
|
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/threading"
|
"github.com/zeromicro/go-zero/core/threading"
|
||||||
)
|
)
|
||||||
@@ -28,46 +27,15 @@ type (
|
|||||||
|
|
||||||
const flushInterval = 5 * time.Minute
|
const flushInterval = 5 * time.Minute
|
||||||
|
|
||||||
var (
|
var pc = &profileCenter{
|
||||||
pc = &profileCenter{
|
|
||||||
slots: make(map[string]*profileSlot),
|
slots: make(map[string]*profileSlot),
|
||||||
}
|
|
||||||
once sync.Once
|
|
||||||
)
|
|
||||||
|
|
||||||
func report(name string, duration time.Duration) {
|
|
||||||
updated := func() bool {
|
|
||||||
pc.lock.RLock()
|
|
||||||
defer pc.lock.RUnlock()
|
|
||||||
|
|
||||||
slot, ok := pc.slots[name]
|
|
||||||
if ok {
|
|
||||||
atomic.AddInt64(&slot.lifecount, 1)
|
|
||||||
atomic.AddInt64(&slot.lastcount, 1)
|
|
||||||
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
|
||||||
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
|
||||||
}
|
|
||||||
return ok
|
|
||||||
}()
|
|
||||||
|
|
||||||
if !updated {
|
|
||||||
func() {
|
|
||||||
pc.lock.Lock()
|
|
||||||
defer pc.lock.Unlock()
|
|
||||||
|
|
||||||
pc.slots[name] = &profileSlot{
|
|
||||||
lifecount: 1,
|
|
||||||
lastcount: 1,
|
|
||||||
lifecycle: int64(duration),
|
|
||||||
lastcycle: int64(duration),
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
once.Do(flushRepeatly)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func flushRepeatly() {
|
func init() {
|
||||||
|
flushRepeatedly()
|
||||||
|
}
|
||||||
|
|
||||||
|
func flushRepeatedly() {
|
||||||
threading.GoSafe(func() {
|
threading.GoSafe(func() {
|
||||||
for {
|
for {
|
||||||
time.Sleep(flushInterval)
|
time.Sleep(flushInterval)
|
||||||
@@ -76,42 +44,64 @@ func flushRepeatly() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func report(name string, duration time.Duration) {
|
||||||
|
slot := loadOrStoreSlot(name, duration)
|
||||||
|
|
||||||
|
atomic.AddInt64(&slot.lifecount, 1)
|
||||||
|
atomic.AddInt64(&slot.lastcount, 1)
|
||||||
|
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
||||||
|
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadOrStoreSlot(name string, duration time.Duration) *profileSlot {
|
||||||
|
pc.lock.RLock()
|
||||||
|
slot, ok := pc.slots[name]
|
||||||
|
pc.lock.RUnlock()
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
return slot
|
||||||
|
}
|
||||||
|
|
||||||
|
pc.lock.Lock()
|
||||||
|
defer pc.lock.Unlock()
|
||||||
|
|
||||||
|
// double-check
|
||||||
|
if slot, ok = pc.slots[name]; ok {
|
||||||
|
return slot
|
||||||
|
}
|
||||||
|
|
||||||
|
slot = &profileSlot{}
|
||||||
|
pc.slots[name] = slot
|
||||||
|
return slot
|
||||||
|
}
|
||||||
|
|
||||||
func generateReport() string {
|
func generateReport() string {
|
||||||
var buffer bytes.Buffer
|
var builder strings.Builder
|
||||||
buffer.WriteString("Profiling report\n")
|
builder.WriteString("Profiling report\n")
|
||||||
var data [][]string
|
builder.WriteString("QUEUE,LIFECOUNT,LIFECYCLE,LASTCOUNT,LASTCYCLE\n")
|
||||||
|
|
||||||
calcFn := func(total, count int64) string {
|
calcFn := func(total, count int64) string {
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
}
|
}
|
||||||
|
|
||||||
return (time.Duration(total) / time.Duration(count)).String()
|
return (time.Duration(total) / time.Duration(count)).String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func() {
|
|
||||||
pc.lock.Lock()
|
pc.lock.Lock()
|
||||||
defer pc.lock.Unlock()
|
|
||||||
|
|
||||||
for key, slot := range pc.slots {
|
for key, slot := range pc.slots {
|
||||||
data = append(data, []string{
|
builder.WriteString(fmt.Sprintf("%s,%d,%s,%d,%s\n",
|
||||||
key,
|
key,
|
||||||
strconv.FormatInt(slot.lifecount, 10),
|
slot.lifecount,
|
||||||
calcFn(slot.lifecycle, slot.lifecount),
|
calcFn(slot.lifecycle, slot.lifecount),
|
||||||
strconv.FormatInt(slot.lastcount, 10),
|
slot.lastcount,
|
||||||
calcFn(slot.lastcycle, slot.lastcount),
|
calcFn(slot.lastcycle, slot.lastcount),
|
||||||
})
|
))
|
||||||
|
|
||||||
// reset the data for last cycle
|
// reset last cycle stats
|
||||||
slot.lastcount = 0
|
atomic.StoreInt64(&slot.lastcount, 0)
|
||||||
slot.lastcycle = 0
|
atomic.StoreInt64(&slot.lastcycle, 0)
|
||||||
}
|
}
|
||||||
}()
|
pc.lock.Unlock()
|
||||||
|
|
||||||
table := tablewriter.NewWriter(&buffer)
|
return builder.String()
|
||||||
table.SetHeader([]string{"QUEUE", "LIFECOUNT", "LIFECYCLE", "LASTCOUNT", "LASTCYCLE"})
|
|
||||||
table.SetBorder(false)
|
|
||||||
table.AppendBulk(data)
|
|
||||||
table.Render()
|
|
||||||
|
|
||||||
return buffer.String()
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestReport(t *testing.T) {
|
func TestReport(t *testing.T) {
|
||||||
once.Do(func() {})
|
|
||||||
assert.NotContains(t, generateReport(), "foo")
|
assert.NotContains(t, generateReport(), "foo")
|
||||||
report("foo", time.Second)
|
report("foo", time.Second)
|
||||||
assert.Contains(t, generateReport(), "foo")
|
assert.Contains(t, generateReport(), "foo")
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
"github.com/zeromicro/go-zero/core/trace"
|
"github.com/zeromicro/go-zero/core/trace"
|
||||||
"github.com/zeromicro/go-zero/internal/devserver"
|
"github.com/zeromicro/go-zero/internal/devserver"
|
||||||
|
"github.com/zeromicro/go-zero/internal/profiling"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -38,6 +39,8 @@ type (
|
|||||||
Telemetry trace.Config `json:",optional"`
|
Telemetry trace.Config `json:",optional"`
|
||||||
DevServer DevServerConfig `json:",optional"`
|
DevServer DevServerConfig `json:",optional"`
|
||||||
Shutdown proc.ShutdownConf `json:",optional"`
|
Shutdown proc.ShutdownConf `json:",optional"`
|
||||||
|
// Profiling is the configuration for continuous profiling.
|
||||||
|
Profiling profiling.Config `json:",optional"`
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -70,7 +73,9 @@ func (sc ServiceConf) SetUp() error {
|
|||||||
if len(sc.MetricsUrl) > 0 {
|
if len(sc.MetricsUrl) > 0 {
|
||||||
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
|
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
|
||||||
}
|
}
|
||||||
|
|
||||||
devserver.StartAgent(sc.DevServer)
|
devserver.StartAgent(sc.DevServer)
|
||||||
|
profiling.Start(sc.Profiling)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/proc"
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
"github.com/zeromicro/go-zero/core/syncx"
|
|
||||||
"github.com/zeromicro/go-zero/core/threading"
|
"github.com/zeromicro/go-zero/core/threading"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,7 +36,7 @@ type (
|
|||||||
// NewServiceGroup returns a ServiceGroup.
|
// NewServiceGroup returns a ServiceGroup.
|
||||||
func NewServiceGroup() *ServiceGroup {
|
func NewServiceGroup() *ServiceGroup {
|
||||||
sg := new(ServiceGroup)
|
sg := new(ServiceGroup)
|
||||||
sg.stopOnce = syncx.Once(sg.doStop)
|
sg.stopOnce = sync.OnceFunc(sg.doStop)
|
||||||
return sg
|
return sg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
const (
|
const (
|
||||||
clusterNameKey = "CLUSTER_NAME"
|
clusterNameKey = "CLUSTER_NAME"
|
||||||
testEnv = "test.v"
|
testEnv = "test.v"
|
||||||
timeFormat = "2006-01-02 15:04:05"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -45,7 +44,7 @@ func Report(msg string) {
|
|||||||
if fn != nil {
|
if fn != nil {
|
||||||
reported := lessExecutor.DoOrDiscard(func() {
|
reported := lessExecutor.DoOrDiscard(func() {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
builder.WriteString(fmt.Sprintln(time.Now().Format(timeFormat)))
|
builder.WriteString(fmt.Sprintln(time.Now().Format(time.DateTime)))
|
||||||
if len(clusterName) > 0 {
|
if len(clusterName) > 0 {
|
||||||
builder.WriteString(fmt.Sprintf("cluster: %s\n", clusterName))
|
builder.WriteString(fmt.Sprintf("cluster: %s\n", clusterName))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -609,6 +609,28 @@ func (s *Redis) GetBitCtx(ctx context.Context, key string, offset int64) (int, e
|
|||||||
return int(v), nil
|
return int(v), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDel is the implementation of redis getdel command.
|
||||||
|
// Available since: redis version 6.2.0
|
||||||
|
func (s *Redis) GetDel(key string) (string, error) {
|
||||||
|
return s.GetDelCtx(context.Background(), key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDelCtx is the implementation of redis getdel command.
|
||||||
|
// Available since: redis version 6.2.0
|
||||||
|
func (s *Redis) GetDelCtx(ctx context.Context, key string) (string, error) {
|
||||||
|
conn, err := getRedis(s)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err := conn.GetDel(ctx, key).Result()
|
||||||
|
if errors.Is(err, red.Nil) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
// GetSet is the implementation of redis getset command.
|
// GetSet is the implementation of redis getset command.
|
||||||
func (s *Redis) GetSet(key, value string) (string, error) {
|
func (s *Redis) GetSet(key, value string) (string, error) {
|
||||||
return s.GetSetCtx(context.Background(), key, value)
|
return s.GetSetCtx(context.Background(), key, value)
|
||||||
|
|||||||
@@ -1071,6 +1071,34 @@ func TestRedis_Set(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRedis_GetDel(t *testing.T) {
|
||||||
|
t.Run("get_del", func(t *testing.T) {
|
||||||
|
runOnRedis(t, func(client *Redis) {
|
||||||
|
val, err := newRedis(client.Addr).GetDel("hello")
|
||||||
|
assert.Equal(t, "", val)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = client.Set("hello", "world")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
val, err = client.Get("hello")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "world", val)
|
||||||
|
val, err = client.GetDel("hello")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "world", val)
|
||||||
|
val, err = client.Get("hello")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "", val)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("get_del_with_error", func(t *testing.T) {
|
||||||
|
runOnRedisWithError(t, func(client *Redis) {
|
||||||
|
_, err := newRedis(client.Addr, badType()).GetDel("hello")
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestRedis_GetSet(t *testing.T) {
|
func TestRedis_GetSet(t *testing.T) {
|
||||||
t.Run("set_get", func(t *testing.T) {
|
t.Run("set_get", func(t *testing.T) {
|
||||||
runOnRedis(t, func(client *Redis) {
|
runOnRedis(t, func(client *Redis) {
|
||||||
|
|||||||
29
core/stores/sqlx/config.go
Normal file
29
core/stores/sqlx/config.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package sqlx
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
errEmptyDatasource = errors.New("empty datasource")
|
||||||
|
errEmptyDriverName = errors.New("empty driver name")
|
||||||
|
)
|
||||||
|
|
||||||
|
// SqlConf defines the configuration for sqlx.
|
||||||
|
type SqlConf struct {
|
||||||
|
DataSource string
|
||||||
|
DriverName string `json:",default=mysql"`
|
||||||
|
Replicas []string `json:",optional"`
|
||||||
|
Policy string `json:",default=round-robin,options=round-robin|random"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the SqlxConf.
|
||||||
|
func (sc SqlConf) Validate() error {
|
||||||
|
if len(sc.DataSource) == 0 {
|
||||||
|
return errEmptyDatasource
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sc.DriverName) == 0 {
|
||||||
|
return errEmptyDriverName
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
29
core/stores/sqlx/config_test.go
Normal file
29
core/stores/sqlx/config_test.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package sqlx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidate(t *testing.T) {
|
||||||
|
text := []byte(`DataSource: primary:password@tcp(127.0.0.1:3306)/primary_db
|
||||||
|
`)
|
||||||
|
|
||||||
|
var sc SqlConf
|
||||||
|
err := conf.LoadFromYamlBytes(text, &sc)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "mysql", sc.DriverName)
|
||||||
|
assert.Equal(t, policyRoundRobin, sc.Policy)
|
||||||
|
assert.Nil(t, sc.Validate())
|
||||||
|
|
||||||
|
sc = SqlConf{}
|
||||||
|
assert.Equal(t, errEmptyDatasource, sc.Validate())
|
||||||
|
|
||||||
|
sc.DataSource = "primary:password@tcp(127.0.0.1:3306)/primary_db"
|
||||||
|
assert.Equal(t, errEmptyDriverName, sc.Validate())
|
||||||
|
|
||||||
|
sc.DriverName = "mysql"
|
||||||
|
assert.Nil(t, sc.Validate())
|
||||||
|
}
|
||||||
@@ -267,6 +267,20 @@ func TestUnmarshalRowStruct(t *testing.T) {
|
|||||||
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
Name string
|
||||||
|
age int
|
||||||
|
})
|
||||||
|
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, rows, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
||||||
|
})
|
||||||
|
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -310,6 +324,20 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
|
|||||||
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
|
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
age int `db:"age"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
})
|
||||||
|
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, rows, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
|
||||||
|
})
|
||||||
|
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
var value struct {
|
var value struct {
|
||||||
Age *int `db:"age"`
|
Age *int `db:"age"`
|
||||||
@@ -1307,6 +1335,7 @@ func TestAnonymousStructPr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAnonymousStructPrError(t *testing.T) {
|
func TestAnonymousStructPrError(t *testing.T) {
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
type Score struct {
|
type Score struct {
|
||||||
Discipline string `db:"discipline"`
|
Discipline string `db:"discipline"`
|
||||||
score uint `db:"score"`
|
score uint `db:"score"`
|
||||||
@@ -1319,13 +1348,12 @@ func TestAnonymousStructPrError(t *testing.T) {
|
|||||||
*ClassType
|
*ClassType
|
||||||
Score
|
Score
|
||||||
}
|
}
|
||||||
|
|
||||||
var value []*struct {
|
var value []*struct {
|
||||||
Age int64 `db:"age"`
|
Age int64 `db:"age"`
|
||||||
Class
|
Class
|
||||||
Name string `db:"name"`
|
Name string `db:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
|
||||||
rs := sqlmock.NewRows([]string{
|
rs := sqlmock.NewRows([]string{
|
||||||
"name",
|
"name",
|
||||||
"age",
|
"age",
|
||||||
@@ -1338,10 +1366,50 @@ func TestAnonymousStructPrError(t *testing.T) {
|
|||||||
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
|
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").
|
mock.ExpectQuery("select (.+) from users where user=?").
|
||||||
WithArgs("anyone").WillReturnRows(rs)
|
WithArgs("anyone").WillReturnRows(rs)
|
||||||
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
return unmarshalRows(&value, rows, true)
|
return unmarshalRows(&value, rows, true)
|
||||||
}, "select name, age, grade, discipline, class_name, score from users where user=?",
|
}, "select name, age, grade, discipline, class_name, score from users where user=?",
|
||||||
"anyone"))
|
"anyone"), ErrNotReadableValue)
|
||||||
|
if len(value) > 0 {
|
||||||
|
assert.Equal(t, value[0].score, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
type Score struct {
|
||||||
|
Discipline string
|
||||||
|
score uint
|
||||||
|
}
|
||||||
|
type ClassType struct {
|
||||||
|
Grade sql.NullString
|
||||||
|
ClassName *string
|
||||||
|
}
|
||||||
|
type Class struct {
|
||||||
|
*ClassType
|
||||||
|
Score
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []*struct {
|
||||||
|
Age int64
|
||||||
|
Class
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
rs := sqlmock.NewRows([]string{
|
||||||
|
"name",
|
||||||
|
"age",
|
||||||
|
"grade",
|
||||||
|
"discipline",
|
||||||
|
"class_name",
|
||||||
|
"score",
|
||||||
|
}).
|
||||||
|
AddRow("first", 2, nil, "math", "experimental class", 100).
|
||||||
|
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").
|
||||||
|
WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, age, grade, discipline, class_name, score from users where user=?",
|
||||||
|
"anyone"), ErrNotMatchDestination)
|
||||||
if len(value) > 0 {
|
if len(value) > 0 {
|
||||||
assert.Equal(t, value[0].score, 0)
|
assert.Equal(t, value[0].score, 0)
|
||||||
}
|
}
|
||||||
|
|||||||
65
core/stores/sqlx/rwstrategy.go
Normal file
65
core/stores/sqlx/rwstrategy.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package sqlx
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// policyRoundRobin round-robin policy for selecting replicas.
|
||||||
|
policyRoundRobin = "round-robin"
|
||||||
|
// policyRandom random policy for selecting replicas.
|
||||||
|
policyRandom = "random"
|
||||||
|
|
||||||
|
// readPrimaryMode indicates that the operation is a read,
|
||||||
|
// but should be performed on the primary database instance.
|
||||||
|
//
|
||||||
|
// This mode is used in scenarios where data freshness and consistency are critical,
|
||||||
|
// such as immediately after writes or where replication lag may cause stale reads.
|
||||||
|
readPrimaryMode readWriteMode = "read-primary"
|
||||||
|
|
||||||
|
// readReplicaMode indicates that the operation is a read from replicas.
|
||||||
|
// This is suitable for scenarios where eventual consistency is acceptable,
|
||||||
|
// and the goal is to offload traffic from the primary and improve read scalability.
|
||||||
|
readReplicaMode readWriteMode = "read-replica"
|
||||||
|
|
||||||
|
// writeMode indicates that the operation is a write operation (to primary).
|
||||||
|
writeMode readWriteMode = "write"
|
||||||
|
|
||||||
|
// notSpecifiedMode indicates that the read/write mode is not specified.
|
||||||
|
notSpecifiedMode readWriteMode = ""
|
||||||
|
)
|
||||||
|
|
||||||
|
type readWriteModeKey struct{}
|
||||||
|
|
||||||
|
// WithReadPrimary sets the context to read-primary mode.
|
||||||
|
func WithReadPrimary(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, readWriteModeKey{}, readPrimaryMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithReadReplica sets the context to read-replica mode.
|
||||||
|
func WithReadReplica(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, readWriteModeKey{}, readReplicaMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithWrite sets the context to write mode, indicating that the operation is a write operation.
|
||||||
|
func WithWrite(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, readWriteModeKey{}, writeMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
type readWriteMode string
|
||||||
|
|
||||||
|
func (m readWriteMode) isValid() bool {
|
||||||
|
return m == readPrimaryMode || m == readReplicaMode || m == writeMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func getReadWriteMode(ctx context.Context) readWriteMode {
|
||||||
|
if mode := ctx.Value(readWriteModeKey{}); mode != nil {
|
||||||
|
if v, ok := mode.(readWriteMode); ok && v.isValid() {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return notSpecifiedMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func usePrimary(ctx context.Context) bool {
|
||||||
|
return getReadWriteMode(ctx) != readReplicaMode
|
||||||
|
}
|
||||||
142
core/stores/sqlx/rwstrategy_test.go
Normal file
142
core/stores/sqlx/rwstrategy_test.go
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
package sqlx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsValid(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
mode readWriteMode
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid read-primary mode",
|
||||||
|
mode: readPrimaryMode,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid read-replica mode",
|
||||||
|
mode: readReplicaMode,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid write mode",
|
||||||
|
mode: writeMode,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not specified mode (empty)",
|
||||||
|
mode: notSpecifiedMode,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid custom string",
|
||||||
|
mode: readWriteMode("delete"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case sensitive check",
|
||||||
|
mode: readWriteMode("READ"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
actual := tc.mode.isValid()
|
||||||
|
assert.Equal(t, tc.expected, actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithReadMode(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
readPrimaryCtx := WithReadPrimary(ctx)
|
||||||
|
|
||||||
|
val := readPrimaryCtx.Value(readWriteModeKey{})
|
||||||
|
assert.Equal(t, readPrimaryMode, val)
|
||||||
|
|
||||||
|
readReplicaCtx := WithReadReplica(ctx)
|
||||||
|
val = readReplicaCtx.Value(readWriteModeKey{})
|
||||||
|
assert.Equal(t, readReplicaMode, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithWriteMode(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
writeCtx := WithWrite(ctx)
|
||||||
|
|
||||||
|
val := writeCtx.Value(readWriteModeKey{})
|
||||||
|
assert.Equal(t, writeMode, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetReadWriteMode(t *testing.T) {
|
||||||
|
t.Run("valid read-primary mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
|
||||||
|
assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("valid read-replica mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
|
||||||
|
assert.Equal(t, readReplicaMode, getReadWriteMode(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("valid write mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
|
||||||
|
assert.Equal(t, writeMode, getReadWriteMode(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid mode value (wrong type)", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, "not-a-mode")
|
||||||
|
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid mode value (wrong value)", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("delete"))
|
||||||
|
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no mode set", func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsePrimary(t *testing.T) {
|
||||||
|
t.Run("context with read-replica mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
|
||||||
|
assert.False(t, usePrimary(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context with read-primary mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
|
||||||
|
assert.True(t, usePrimary(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context with write mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
|
||||||
|
assert.True(t, usePrimary(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context with invalid mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("invalid"))
|
||||||
|
assert.True(t, usePrimary(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context with no mode set", func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
assert.True(t, usePrimary(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithModeTwice(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx = WithReadPrimary(ctx)
|
||||||
|
writeCtx := WithWrite(ctx)
|
||||||
|
|
||||||
|
val := writeCtx.Value(readWriteModeKey{})
|
||||||
|
assert.Equal(t, writeMode, val)
|
||||||
|
}
|
||||||
@@ -4,6 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/breaker"
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
"github.com/zeromicro/go-zero/core/errorx"
|
"github.com/zeromicro/go-zero/core/errorx"
|
||||||
@@ -52,9 +55,10 @@ type (
|
|||||||
beginTx beginnable
|
beginTx beginnable
|
||||||
brk breaker.Breaker
|
brk breaker.Breaker
|
||||||
accept breaker.Acceptable
|
accept breaker.Acceptable
|
||||||
|
index uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
connProvider func() (*sql.DB, error)
|
connProvider func(ctx context.Context) (*sql.DB, error)
|
||||||
|
|
||||||
sessionConn interface {
|
sessionConn interface {
|
||||||
Exec(query string, args ...any) (sql.Result, error)
|
Exec(query string, args ...any) (sql.Result, error)
|
||||||
@@ -64,10 +68,41 @@ type (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MustNewConn returns a SqlConn with the given SqlConf.
|
||||||
|
func MustNewConn(c SqlConf, opts ...SqlOption) SqlConn {
|
||||||
|
conn, err := NewConn(c, opts...)
|
||||||
|
if err != nil {
|
||||||
|
logx.Must(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConn returns a SqlConn with the given SqlConf.
|
||||||
|
func NewConn(c SqlConf, opts ...SqlOption) (SqlConn, error) {
|
||||||
|
if err := c.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &commonSqlConn{
|
||||||
|
onError: func(ctx context.Context, err error) {
|
||||||
|
logInstanceError(ctx, c.DataSource, err)
|
||||||
|
},
|
||||||
|
beginTx: begin,
|
||||||
|
brk: breaker.NewBreaker(),
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(conn)
|
||||||
|
}
|
||||||
|
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
// NewSqlConn returns a SqlConn with given driver name and datasource.
|
// NewSqlConn returns a SqlConn with given driver name and datasource.
|
||||||
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||||
conn := &commonSqlConn{
|
conn := &commonSqlConn{
|
||||||
connProv: func() (*sql.DB, error) {
|
connProv: func(context.Context) (*sql.DB, error) {
|
||||||
return getSqlConn(driverName, datasource)
|
return getSqlConn(driverName, datasource)
|
||||||
},
|
},
|
||||||
onError: func(ctx context.Context, err error) {
|
onError: func(ctx context.Context, err error) {
|
||||||
@@ -87,7 +122,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
|||||||
// Use it with caution; it's provided for other ORM to interact with.
|
// Use it with caution; it's provided for other ORM to interact with.
|
||||||
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
|
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
|
||||||
conn := &commonSqlConn{
|
conn := &commonSqlConn{
|
||||||
connProv: func() (*sql.DB, error) {
|
connProv: func(ctx context.Context) (*sql.DB, error) {
|
||||||
return db, nil
|
return db, nil
|
||||||
},
|
},
|
||||||
onError: func(ctx context.Context, err error) {
|
onError: func(ctx context.Context, err error) {
|
||||||
@@ -123,7 +158,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
|
|||||||
|
|
||||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||||
var conn *sql.DB
|
var conn *sql.DB
|
||||||
conn, err = db.connProv()
|
conn, err = db.connProv(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.onError(ctx, err)
|
db.onError(ctx, err)
|
||||||
return err
|
return err
|
||||||
@@ -151,7 +186,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
|
|||||||
|
|
||||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||||
var conn *sql.DB
|
var conn *sql.DB
|
||||||
conn, err = db.connProv()
|
conn, err = db.connProv(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.onError(ctx, err)
|
db.onError(ctx, err)
|
||||||
return err
|
return err
|
||||||
@@ -242,7 +277,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
|
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
|
||||||
return db.connProv()
|
return db.connProv(context.Background())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
||||||
@@ -288,7 +323,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
|||||||
q string, args ...any) (err error) {
|
q string, args ...any) (err error) {
|
||||||
var scanFailed bool
|
var scanFailed bool
|
||||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||||
conn, err := db.connProv()
|
conn, err := db.connProv(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.onError(ctx, err)
|
db.onError(ctx, err)
|
||||||
return err
|
return err
|
||||||
@@ -311,6 +346,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, replicas []string) connProvider {
|
||||||
|
return func(ctx context.Context) (*sql.DB, error) {
|
||||||
|
replicaCount := len(replicas)
|
||||||
|
|
||||||
|
if replicaCount == 0 || usePrimary(ctx) {
|
||||||
|
return getSqlConn(driverName, datasource)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dsn string
|
||||||
|
|
||||||
|
if replicaCount == 1 {
|
||||||
|
dsn = replicas[0]
|
||||||
|
} else {
|
||||||
|
if len(policy) == 0 {
|
||||||
|
policy = policyRoundRobin
|
||||||
|
}
|
||||||
|
|
||||||
|
switch policy {
|
||||||
|
case policyRandom:
|
||||||
|
dsn = replicas[rand.Intn(replicaCount)]
|
||||||
|
case policyRoundRobin:
|
||||||
|
index := atomic.AddUint32(&sc.index, 1) - 1
|
||||||
|
dsn = replicas[index%uint32(replicaCount)]
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown policy: %s", policy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return getSqlConn(driverName, dsn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithAcceptable returns a SqlOption that setting the acceptable function.
|
// WithAcceptable returns a SqlOption that setting the acceptable function.
|
||||||
// acceptable is the func to check if the error can be accepted.
|
// acceptable is the func to check if the error can be accepted.
|
||||||
func WithAcceptable(acceptable func(err error) bool) SqlOption {
|
func WithAcceptable(acceptable func(err error) bool) SqlOption {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package sqlx
|
package sqlx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
@@ -98,7 +99,7 @@ func TestSqlConn_RawDB(t *testing.T) {
|
|||||||
func TestSqlConn_Errors(t *testing.T) {
|
func TestSqlConn_Errors(t *testing.T) {
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
conn := NewSqlConnFromDB(db)
|
conn := NewSqlConnFromDB(db)
|
||||||
conn.(*commonSqlConn).connProv = func() (*sql.DB, error) {
|
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
|
||||||
return nil, errors.New("error")
|
return nil, errors.New("error")
|
||||||
}
|
}
|
||||||
_, err := conn.Prepare("any")
|
_, err := conn.Prepare("any")
|
||||||
@@ -138,6 +139,148 @@ func TestSqlConn_Errors(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfigSqlConn(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
assert.NotNil(t, db)
|
||||||
|
assert.NotNil(t, mock)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
connManager.Inject(mockedDatasource, db)
|
||||||
|
mock.ExpectExec("any")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
|
||||||
|
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf, withMysqlAcceptable())
|
||||||
|
|
||||||
|
_, err = conn.Exec("any", "value")
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
_, err = conn.Prepare("any")
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
|
var val string
|
||||||
|
assert.NotNil(t, conn.QueryRow(&val, "any"))
|
||||||
|
assert.NotNil(t, conn.QueryRowPartial(&val, "any"))
|
||||||
|
assert.NotNil(t, conn.QueryRows(&val, "any"))
|
||||||
|
assert.NotNil(t, conn.QueryRowsPartial(&val, "any"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigSqlConnStatement(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
assert.NotNil(t, db)
|
||||||
|
assert.NotNil(t, mock)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
connManager.Inject(mockedDatasource, db)
|
||||||
|
|
||||||
|
mock.ExpectPrepare("any")
|
||||||
|
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||||
|
mock.ExpectPrepare("any")
|
||||||
|
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(row)
|
||||||
|
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf, withMysqlAcceptable())
|
||||||
|
stmt, err := conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
res, err := stmt.Exec()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
lastInsertID, err := res.LastInsertId()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(2), lastInsertID)
|
||||||
|
rowsAffected, err := res.RowsAffected()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(3), rowsAffected)
|
||||||
|
|
||||||
|
stmt, err = conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var val string
|
||||||
|
err = stmt.QueryRow(&val)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "bar", val)
|
||||||
|
|
||||||
|
mock.ExpectPrepare("any")
|
||||||
|
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
|
||||||
|
stmt, err = conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var vals []string
|
||||||
|
assert.NoError(t, stmt.QueryRowsPartial(&vals))
|
||||||
|
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigSqlConnQuery(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
assert.NotNil(t, db)
|
||||||
|
assert.NotNil(t, mock)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
connManager.Inject(mockedDatasource, db)
|
||||||
|
|
||||||
|
t.Run("QueryRow", func(t *testing.T) {
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf)
|
||||||
|
var val string
|
||||||
|
assert.NoError(t, conn.QueryRow(&val, "any"))
|
||||||
|
assert.Equal(t, "bar", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("QueryRowPartial", func(t *testing.T) {
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf)
|
||||||
|
var val string
|
||||||
|
assert.NoError(t, conn.QueryRowPartial(&val, "any"))
|
||||||
|
assert.Equal(t, "bar", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("QueryRows", func(t *testing.T) {
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf)
|
||||||
|
var vals []string
|
||||||
|
assert.NoError(t, conn.QueryRows(&vals, "any"))
|
||||||
|
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("QueryRowsPartial", func(t *testing.T) {
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf)
|
||||||
|
var vals []string
|
||||||
|
assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
|
||||||
|
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigSqlConnErr(t *testing.T) {
|
||||||
|
t.Run("panic on empty config", func(t *testing.T) {
|
||||||
|
original := logx.ExitOnFatal.True()
|
||||||
|
logx.ExitOnFatal.Set(false)
|
||||||
|
defer logx.ExitOnFatal.Set(original)
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
MustNewConn(SqlConf{})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("on error", func(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
assert.NotNil(t, db)
|
||||||
|
assert.NotNil(t, mock)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
connManager.Inject(mockedDatasource, db)
|
||||||
|
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf)
|
||||||
|
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
|
||||||
|
return nil, errors.New("error")
|
||||||
|
}
|
||||||
|
_, err = conn.Prepare("any")
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestStatement(t *testing.T) {
|
func TestStatement(t *testing.T) {
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
mock.ExpectPrepare("any").WillBeClosed()
|
mock.ExpectPrepare("any").WillBeClosed()
|
||||||
@@ -303,6 +446,93 @@ func TestWithAcceptable(t *testing.T) {
|
|||||||
assert.True(t, conn.accept(acceptableErr3))
|
assert.True(t, conn.accept(acceptableErr3))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProvider(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
_ = connManager.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
primaryDSN := "primary:password@tcp(127.0.0.1:3306)/primary_db"
|
||||||
|
replicasDSN := []string{
|
||||||
|
"replica_one:pwd@tcp(localhost:3306)/replica_one",
|
||||||
|
"replica_two:pwd@tcp(localhost:3306)/replica_two",
|
||||||
|
"replica_three:pwd@tcp(localhost:3306)/replica_three",
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryDB, err := connManager.GetResource(primaryDSN, func() (io.Closer, error) { return sql.Open(mysqlDriverName, primaryDSN) })
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, primaryDB)
|
||||||
|
replicaOneDB, err := connManager.GetResource(replicasDSN[0], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[0]) })
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, replicaOneDB)
|
||||||
|
replicaTwoDB, err := connManager.GetResource(replicasDSN[1], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[1]) })
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, replicaTwoDB)
|
||||||
|
replicaThreeDB, err := connManager.GetResource(replicasDSN[2], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[2]) })
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, replicaThreeDB)
|
||||||
|
|
||||||
|
sc := &commonSqlConn{}
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, nil)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
db, err := sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, primaryDB, db)
|
||||||
|
|
||||||
|
ctx = WithWrite(ctx)
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, primaryDB, db)
|
||||||
|
|
||||||
|
ctx = WithReadPrimary(ctx)
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, primaryDB, db)
|
||||||
|
|
||||||
|
// no mode set, should return primary
|
||||||
|
ctx = context.Background()
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, primaryDB, db)
|
||||||
|
|
||||||
|
ctx = WithReadReplica(ctx)
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]})
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, replicaOneDB, db)
|
||||||
|
|
||||||
|
// default policy is round-robin
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
|
||||||
|
replicas := []io.Closer{replicaOneDB, replicaTwoDB, replicaThreeDB}
|
||||||
|
for i := 0; i < len(replicasDSN); i++ {
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, replicas[i], db)
|
||||||
|
}
|
||||||
|
|
||||||
|
// random policy
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRandom, replicasDSN)
|
||||||
|
for i := 0; i < len(replicasDSN); i++ {
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Contains(t, replicas, db)
|
||||||
|
}
|
||||||
|
|
||||||
|
// unknown policy
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "unknown", replicasDSN)
|
||||||
|
_, err = sc.connProv(ctx)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
|
// empty policy transforms to round-robin
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "", replicasDSN)
|
||||||
|
for i := 0; i < len(replicasDSN); i++ {
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, replicas[i], db)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
||||||
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
||||||
var db *sql.DB
|
var db *sql.DB
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if driverName != mysqlDriverName {
|
if driverName == mysqlDriverName {
|
||||||
if cfg, e := mysql.ParseDSN(server); e != nil {
|
if cfg, e := mysql.ParseDSN(server); e != nil {
|
||||||
// if cannot parse, don't collect the metrics
|
// if cannot parse, don't collect the metrics
|
||||||
logx.Error(e)
|
logx.Error(e)
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ func begin(db *sql.DB) (trans, error) {
|
|||||||
|
|
||||||
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
|
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
|
||||||
fn func(context.Context, Session) error) (err error) {
|
fn func(context.Context, Session) error) (err error) {
|
||||||
conn, err := db.connProv()
|
conn, err := db.connProv(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.onError(ctx, err)
|
db.onError(ctx, err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ func TestTxExceptions(t *testing.T) {
|
|||||||
|
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
conn := &commonSqlConn{
|
conn := &commonSqlConn{
|
||||||
connProv: func() (*sql.DB, error) {
|
connProv: func(ctx context.Context) (*sql.DB, error) {
|
||||||
return nil, errors.New("foo")
|
return nil, errors.New("foo")
|
||||||
},
|
},
|
||||||
beginTx: begin,
|
beginTx: begin,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package stringx
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"slices"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/lang"
|
"github.com/zeromicro/go-zero/core/lang"
|
||||||
@@ -15,14 +16,9 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Contains checks if str is in list.
|
// Contains checks if str is in list.
|
||||||
|
// Deprecated: use slices.Contains instead.
|
||||||
func Contains(list []string, str string) bool {
|
func Contains(list []string, str string) bool {
|
||||||
for _, each := range list {
|
return slices.Contains(list, str)
|
||||||
if each == str {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter filters chars from s with given filter function.
|
// Filter filters chars from s with given filter function.
|
||||||
@@ -123,11 +119,7 @@ func Remove(strings []string, strs ...string) []string {
|
|||||||
// Reverse reverses s.
|
// Reverse reverses s.
|
||||||
func Reverse(s string) string {
|
func Reverse(s string) string {
|
||||||
runes := []rune(s)
|
runes := []rune(s)
|
||||||
|
slices.Reverse(runes)
|
||||||
for from, to := 0, len(runes)-1; from < to; from, to = from+1, to-1 {
|
|
||||||
runes[from], runes[to] = runes[to], runes[from]
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(runes)
|
return string(runes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,28 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestContainsString(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
slice []string
|
||||||
|
value string
|
||||||
|
expect bool
|
||||||
|
}{
|
||||||
|
{[]string{"1"}, "1", true},
|
||||||
|
{[]string{"1"}, "2", false},
|
||||||
|
{[]string{"1", "2"}, "1", true},
|
||||||
|
{[]string{"1", "2"}, "3", false},
|
||||||
|
{nil, "3", false},
|
||||||
|
{nil, "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, each := range cases {
|
||||||
|
t.Run(path.Join(each.slice...), func(t *testing.T) {
|
||||||
|
actual := Contains(each.slice, each.value)
|
||||||
|
assert.Equal(t, each.expect, actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNotEmpty(t *testing.T) {
|
func TestNotEmpty(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
args []string
|
args []string
|
||||||
@@ -41,28 +63,6 @@ func TestNotEmpty(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContainsString(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
slice []string
|
|
||||||
value string
|
|
||||||
expect bool
|
|
||||||
}{
|
|
||||||
{[]string{"1"}, "1", true},
|
|
||||||
{[]string{"1"}, "2", false},
|
|
||||||
{[]string{"1", "2"}, "1", true},
|
|
||||||
{[]string{"1", "2"}, "3", false},
|
|
||||||
{nil, "3", false},
|
|
||||||
{nil, "", false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, each := range cases {
|
|
||||||
t.Run(path.Join(each.slice...), func(t *testing.T) {
|
|
||||||
actual := Contains(each.slice, each.value)
|
|
||||||
assert.Equal(t, each.expect, actual)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFilter(t *testing.T) {
|
func TestFilter(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
input string
|
input string
|
||||||
|
|||||||
@@ -3,9 +3,7 @@ package syncx
|
|||||||
import "sync"
|
import "sync"
|
||||||
|
|
||||||
// Once returns a func that guarantees fn can only called once.
|
// Once returns a func that guarantees fn can only called once.
|
||||||
|
// Deprecated: use sync.OnceFunc instead.
|
||||||
func Once(fn func()) func() {
|
func Once(fn func()) func() {
|
||||||
once := new(sync.Once)
|
return sync.OnceFunc(fn)
|
||||||
return func() {
|
|
||||||
once.Do(fn)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/mathx"
|
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
"github.com/zeromicro/go-zero/core/stringx"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ func compare(v1, v2 string) int {
|
|||||||
fields1, fields2 := strings.Split(v1, "."), strings.Split(v2, ".")
|
fields1, fields2 := strings.Split(v1, "."), strings.Split(v2, ".")
|
||||||
ver1, ver2 := strsToInts(fields1), strsToInts(fields2)
|
ver1, ver2 := strsToInts(fields1), strsToInts(fields2)
|
||||||
ver1len, ver2len := len(ver1), len(ver2)
|
ver1len, ver2len := len(ver1), len(ver2)
|
||||||
shorter := mathx.MinInt(ver1len, ver2len)
|
shorter := min(ver1len, ver2len)
|
||||||
|
|
||||||
for i := 0; i < shorter; i++ {
|
for i := 0; i < shorter; i++ {
|
||||||
if ver1[i] == ver2[i] {
|
if ver1[i] == ver2[i] {
|
||||||
@@ -50,14 +50,7 @@ func compare(v1, v2 string) int {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return cmp.Compare(ver1len, ver2len)
|
||||||
if ver1len < ver2len {
|
|
||||||
return -1
|
|
||||||
} else if ver1len == ver2len {
|
|
||||||
return 0
|
|
||||||
} else {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func strsToInts(strs []string) []int64 {
|
func strsToInts(strs []string) []int64 {
|
||||||
|
|||||||
@@ -201,6 +201,13 @@ func TestHttpToHttp(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("method not allowed", func(t *testing.T) {
|
||||||
|
resp, err := httpc.Do(context.Background(), http.MethodPost,
|
||||||
|
"http://localhost:18882/api/ping", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHttpToHttpBadUpstream(t *testing.T) {
|
func TestHttpToHttpBadUpstream(t *testing.T) {
|
||||||
|
|||||||
18
go.mod
18
go.mod
@@ -4,25 +4,25 @@ go 1.21
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
github.com/alicebob/miniredis/v2 v2.34.0
|
github.com/alicebob/miniredis/v2 v2.35.0
|
||||||
github.com/fatih/color v1.18.0
|
github.com/fatih/color v1.18.0
|
||||||
github.com/fullstorydev/grpcurl v1.9.2
|
github.com/fullstorydev/grpcurl v1.9.3
|
||||||
github.com/go-sql-driver/mysql v1.9.0
|
github.com/go-sql-driver/mysql v1.9.0
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.1
|
github.com/golang-jwt/jwt/v4 v4.5.2
|
||||||
github.com/golang/mock v1.6.0
|
github.com/golang/mock v1.6.0
|
||||||
github.com/golang/protobuf v1.5.4
|
github.com/golang/protobuf v1.5.4
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/jackc/pgx/v5 v5.7.2
|
github.com/grafana/pyroscope-go v1.2.2
|
||||||
|
github.com/jackc/pgx/v5 v5.7.4
|
||||||
github.com/jhump/protoreflect v1.17.0
|
github.com/jhump/protoreflect v1.17.0
|
||||||
github.com/olekukonko/tablewriter v0.0.5
|
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2
|
github.com/pelletier/go-toml/v2 v2.2.2
|
||||||
github.com/prometheus/client_golang v1.21.0
|
github.com/prometheus/client_golang v1.21.1
|
||||||
github.com/redis/go-redis/v9 v9.7.1
|
github.com/redis/go-redis/v9 v9.11.0
|
||||||
github.com/spaolacci/murmur3 v1.1.0
|
github.com/spaolacci/murmur3 v1.1.0
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
go.etcd.io/etcd/api/v3 v3.5.15
|
go.etcd.io/etcd/api/v3 v3.5.15
|
||||||
go.etcd.io/etcd/client/v3 v3.5.15
|
go.etcd.io/etcd/client/v3 v3.5.15
|
||||||
go.mongodb.org/mongo-driver v1.17.3
|
go.mongodb.org/mongo-driver v1.17.4
|
||||||
go.opentelemetry.io/otel v1.24.0
|
go.opentelemetry.io/otel v1.24.0
|
||||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
|
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
|
||||||
@@ -50,7 +50,6 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
filippo.io/edwards25519 v1.1.0 // indirect
|
filippo.io/edwards25519 v1.1.0 // indirect
|
||||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
|
|
||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
github.com/bufbuild/protocompile v0.14.1 // indirect
|
github.com/bufbuild/protocompile v0.14.1 // indirect
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||||
@@ -73,6 +72,7 @@ require (
|
|||||||
github.com/google/gnostic-models v0.6.8 // indirect
|
github.com/google/gnostic-models v0.6.8 // indirect
|
||||||
github.com/google/go-cmp v0.6.0 // indirect
|
github.com/google/go-cmp v0.6.0 // indirect
|
||||||
github.com/google/gofuzz v1.2.0 // indirect
|
github.com/google/gofuzz v1.2.0 // indirect
|
||||||
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
||||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
|||||||
37
go.sum
37
go.sum
@@ -2,10 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
|||||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
|
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
|
||||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
|
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||||
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
|
|
||||||
github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8=
|
|
||||||
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
|
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
|
||||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
@@ -40,8 +38,8 @@ github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU
|
|||||||
github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew=
|
github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew=
|
||||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||||
github.com/fullstorydev/grpcurl v1.9.2 h1:ObqVQTZW7aFnhuqQoppUrvep2duMBanB0UYK2Mm8euo=
|
github.com/fullstorydev/grpcurl v1.9.3 h1:PC1Xi3w+JAvEE2Tg2Gf2RfVgPbf9+tbuQr1ZkyVU3jk=
|
||||||
github.com/fullstorydev/grpcurl v1.9.2/go.mod h1:jLfcF55HAz6TYIJY9xFFWgsl0D7o2HlxA5Z4lUG0Tdo=
|
github.com/fullstorydev/grpcurl v1.9.3/go.mod h1:/b4Wxe8bG6ndAjlfSUjwseQReUDUvBJiFEB7UllOlUE=
|
||||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||||
@@ -62,8 +60,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4
|
|||||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo=
|
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
@@ -82,6 +80,10 @@ github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJY
|
|||||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
|
||||||
|
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
|
||||||
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
|
||||||
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
|
||||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
||||||
@@ -90,8 +92,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
|||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI=
|
github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
|
||||||
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5XumQh94=
|
github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5XumQh94=
|
||||||
@@ -121,7 +123,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
|||||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
|
||||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
@@ -135,8 +136,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
|||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||||
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4=
|
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4=
|
||||||
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
|
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
|
||||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
|
||||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
|
||||||
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
|
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
|
||||||
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
|
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
|
||||||
github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg=
|
github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg=
|
||||||
@@ -151,16 +150,16 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
||||||
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
||||||
github.com/prometheus/client_golang v1.21.0 h1:DIsaGmiaBkSangBgMtWdNfxbMNdku5IK6iNhrEqWvdA=
|
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
|
||||||
github.com/prometheus/client_golang v1.21.0/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
|
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
|
||||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||||
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
||||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||||
github.com/redis/go-redis/v9 v9.7.1 h1:4LhKRCIduqXqtvCUlaq9c8bdHOkICjDMrr1+Zb3osAc=
|
github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs=
|
||||||
github.com/redis/go-redis/v9 v9.7.1/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw=
|
github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
@@ -203,8 +202,8 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5
|
|||||||
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
|
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
|
||||||
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
|
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
|
||||||
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
|
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
|
||||||
go.mongodb.org/mongo-driver v1.17.3 h1:TQyXhnsWfWtgAhMtOgtYHMTkZIfBTpMTsMnd9ZBeHxQ=
|
go.mongodb.org/mongo-driver v1.17.4 h1:jUorfmVzljjr0FLzYQsGP8cgN/qzzxlY9Vh0C9KFXVw=
|
||||||
go.mongodb.org/mongo-driver v1.17.3/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
go.mongodb.org/mongo-driver v1.17.4/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||||
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
||||||
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
||||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=
|
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=
|
||||||
|
|||||||
263
internal/profiling/profiling.go
Normal file
263
internal/profiling/profiling.go
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
package profiling
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/grafana/pyroscope-go"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
|
"github.com/zeromicro/go-zero/core/threading"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultCheckInterval = time.Second * 10
|
||||||
|
defaultProfilingDuration = time.Minute * 2
|
||||||
|
defaultUploadRate = time.Second * 15
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
Config struct {
|
||||||
|
// Name is the name of the application.
|
||||||
|
Name string `json:",optional,inherit"`
|
||||||
|
// ServerAddr is the address of the profiling server.
|
||||||
|
ServerAddr string
|
||||||
|
// AuthUser is the username for basic authentication.
|
||||||
|
AuthUser string `json:",optional"`
|
||||||
|
// AuthPassword is the password for basic authentication.
|
||||||
|
AuthPassword string `json:",optional"`
|
||||||
|
// UploadRate is the duration for which profiling data is uploaded.
|
||||||
|
UploadRate time.Duration `json:",default=15s"`
|
||||||
|
// CheckInterval is the interval to check if profiling should start.
|
||||||
|
CheckInterval time.Duration `json:",default=10s"`
|
||||||
|
// ProfilingDuration is the duration for which profiling data is collected.
|
||||||
|
ProfilingDuration time.Duration `json:",default=2m"`
|
||||||
|
// CpuThreshold the collection is allowed only when the current service cpu > CpuThreshold
|
||||||
|
CpuThreshold int64 `json:",default=700,range=[0:1000)"`
|
||||||
|
|
||||||
|
// ProfileType is the type of profiling to be performed.
|
||||||
|
ProfileType ProfileType
|
||||||
|
}
|
||||||
|
|
||||||
|
ProfileType struct {
|
||||||
|
// Logger is a flag to enable or disable logging.
|
||||||
|
Logger bool `json:",default=false"`
|
||||||
|
// CPU is a flag to disable CPU profiling.
|
||||||
|
CPU bool `json:",default=true"`
|
||||||
|
// Goroutines is a flag to disable goroutine profiling.
|
||||||
|
Goroutines bool `json:",default=true"`
|
||||||
|
// Memory is a flag to disable memory profiling.
|
||||||
|
Memory bool `json:",default=true"`
|
||||||
|
// Mutex is a flag to disable mutex profiling.
|
||||||
|
Mutex bool `json:",default=false"`
|
||||||
|
// Block is a flag to disable block profiling.
|
||||||
|
Block bool `json:",default=false"`
|
||||||
|
}
|
||||||
|
|
||||||
|
profiler interface {
|
||||||
|
Start() error
|
||||||
|
Stop() error
|
||||||
|
}
|
||||||
|
|
||||||
|
pyroscopeProfiler struct {
|
||||||
|
c Config
|
||||||
|
profiler *pyroscope.Profiler
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
once sync.Once
|
||||||
|
|
||||||
|
newProfiler = func(c Config) profiler {
|
||||||
|
return newPyroscopeProfiler(c)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Start initializes the pyroscope profiler with the given configuration.
|
||||||
|
func Start(c Config) {
|
||||||
|
// check if the profiling is enabled
|
||||||
|
if len(c.ServerAddr) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// set default values for the configuration
|
||||||
|
if c.ProfilingDuration <= 0 {
|
||||||
|
c.ProfilingDuration = defaultProfilingDuration
|
||||||
|
}
|
||||||
|
|
||||||
|
// set default values for the configuration
|
||||||
|
if c.CheckInterval <= 0 {
|
||||||
|
c.CheckInterval = defaultCheckInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.UploadRate <= 0 {
|
||||||
|
c.UploadRate = defaultUploadRate
|
||||||
|
}
|
||||||
|
|
||||||
|
once.Do(func() {
|
||||||
|
logx.Info("continuous profiling started")
|
||||||
|
|
||||||
|
threading.GoSafe(func() {
|
||||||
|
startPyroscope(c, proc.Done())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// startPyroscope starts the pyroscope profiler with the given configuration.
|
||||||
|
func startPyroscope(c Config, done <-chan struct{}) {
|
||||||
|
var (
|
||||||
|
pr profiler
|
||||||
|
err error
|
||||||
|
latestProfilingTime time.Time
|
||||||
|
intervalTicker = time.NewTicker(c.CheckInterval)
|
||||||
|
profilingTicker = time.NewTicker(c.ProfilingDuration)
|
||||||
|
)
|
||||||
|
|
||||||
|
defer profilingTicker.Stop()
|
||||||
|
defer intervalTicker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-intervalTicker.C:
|
||||||
|
// Check if the machine is overloaded and if the profiler is not running
|
||||||
|
if pr == nil && isCpuOverloaded(c) {
|
||||||
|
pr = newProfiler(c)
|
||||||
|
if err := pr.Start(); err != nil {
|
||||||
|
logx.Errorf("failed to start profiler: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// record the latest profiling time
|
||||||
|
latestProfilingTime = time.Now()
|
||||||
|
logx.Infof("pyroscope profiler started.")
|
||||||
|
}
|
||||||
|
case <-profilingTicker.C:
|
||||||
|
// check if the profiling duration has passed
|
||||||
|
if !time.Now().After(latestProfilingTime.Add(c.ProfilingDuration)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if the profiler is already running, if so, skip
|
||||||
|
if pr != nil {
|
||||||
|
if err = pr.Stop(); err != nil {
|
||||||
|
logx.Errorf("failed to stop profiler: %v", err)
|
||||||
|
}
|
||||||
|
logx.Infof("pyroscope profiler stopped.")
|
||||||
|
pr = nil
|
||||||
|
}
|
||||||
|
case <-done:
|
||||||
|
logx.Infof("continuous profiling stopped.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// genPyroscopeConf generates the pyroscope configuration based on the given config.
|
||||||
|
func genPyroscopeConf(c Config) pyroscope.Config {
|
||||||
|
pConf := pyroscope.Config{
|
||||||
|
UploadRate: c.UploadRate,
|
||||||
|
ApplicationName: c.Name,
|
||||||
|
BasicAuthUser: c.AuthUser, // http basic auth user
|
||||||
|
BasicAuthPassword: c.AuthPassword, // http basic auth password
|
||||||
|
ServerAddress: c.ServerAddr,
|
||||||
|
Logger: nil,
|
||||||
|
HTTPHeaders: map[string]string{},
|
||||||
|
// you can provide static tags via a map:
|
||||||
|
Tags: map[string]string{
|
||||||
|
"name": c.Name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ProfileType.Logger {
|
||||||
|
pConf.Logger = logx.WithCallerSkip(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ProfileType.CPU {
|
||||||
|
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileCPU)
|
||||||
|
}
|
||||||
|
if c.ProfileType.Goroutines {
|
||||||
|
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileGoroutines)
|
||||||
|
}
|
||||||
|
if c.ProfileType.Memory {
|
||||||
|
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileAllocObjects, pyroscope.ProfileAllocSpace,
|
||||||
|
pyroscope.ProfileInuseObjects, pyroscope.ProfileInuseSpace)
|
||||||
|
}
|
||||||
|
if c.ProfileType.Mutex {
|
||||||
|
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileMutexCount, pyroscope.ProfileMutexDuration)
|
||||||
|
}
|
||||||
|
if c.ProfileType.Block {
|
||||||
|
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileBlockCount, pyroscope.ProfileBlockDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.Infof("applicationName: %s", pConf.ApplicationName)
|
||||||
|
|
||||||
|
return pConf
|
||||||
|
}
|
||||||
|
|
||||||
|
// isCpuOverloaded checks the machine performance based on the given configuration.
|
||||||
|
func isCpuOverloaded(c Config) bool {
|
||||||
|
currentValue := stat.CpuUsage()
|
||||||
|
if currentValue >= c.CpuThreshold {
|
||||||
|
logx.Infof("continuous profiling cpu overload, cpu: %d", currentValue)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPyroscopeProfiler(c Config) profiler {
|
||||||
|
return &pyroscopeProfiler{
|
||||||
|
c: c,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pyroscopeProfiler) Start() error {
|
||||||
|
pConf := genPyroscopeConf(p.c)
|
||||||
|
// set mutex and block profile rate
|
||||||
|
setFraction(p.c)
|
||||||
|
prof, err := pyroscope.Start(pConf)
|
||||||
|
if err != nil {
|
||||||
|
resetFraction(p.c)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.profiler = prof
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pyroscopeProfiler) Stop() error {
|
||||||
|
if p.profiler == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.profiler.Stop(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resetFraction(p.c)
|
||||||
|
p.profiler = nil
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setFraction(c Config) {
|
||||||
|
// These 2 lines are only required if you're using mutex or block profiling
|
||||||
|
if c.ProfileType.Mutex {
|
||||||
|
runtime.SetMutexProfileFraction(10) // 10/seconds
|
||||||
|
}
|
||||||
|
if c.ProfileType.Block {
|
||||||
|
runtime.SetBlockProfileRate(1000 * 1000) // 1/millisecond
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetFraction(c Config) {
|
||||||
|
// These 2 lines are only required if you're using mutex or block profiling
|
||||||
|
if c.ProfileType.Mutex {
|
||||||
|
runtime.SetMutexProfileFraction(0)
|
||||||
|
}
|
||||||
|
if c.ProfileType.Block {
|
||||||
|
runtime.SetBlockProfileRate(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
177
internal/profiling/profiling_test.go
Normal file
177
internal/profiling/profiling_test.go
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
package profiling
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/grafana/pyroscope-go"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/syncx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStart(t *testing.T) {
|
||||||
|
t.Run("profiling", func(t *testing.T) {
|
||||||
|
var c Config
|
||||||
|
assert.NoError(t, conf.FillDefault(&c))
|
||||||
|
c.Name = "test"
|
||||||
|
p := newProfiler(c)
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.NoError(t, p.Start())
|
||||||
|
assert.NoError(t, p.Stop())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid config", func(t *testing.T) {
|
||||||
|
mp := &mockProfiler{}
|
||||||
|
newProfiler = func(c Config) profiler {
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
Start(Config{})
|
||||||
|
|
||||||
|
Start(Config{
|
||||||
|
ServerAddr: "localhost:4040",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test start profiler", func(t *testing.T) {
|
||||||
|
mp := &mockProfiler{}
|
||||||
|
newProfiler = func(c Config) profiler {
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
c := Config{
|
||||||
|
Name: "test",
|
||||||
|
ServerAddr: "localhost:4040",
|
||||||
|
CheckInterval: time.Millisecond,
|
||||||
|
ProfilingDuration: time.Millisecond * 10,
|
||||||
|
CpuThreshold: 0,
|
||||||
|
}
|
||||||
|
var done = make(chan struct{})
|
||||||
|
go startPyroscope(c, done)
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
close(done)
|
||||||
|
|
||||||
|
assert.True(t, mp.started.True())
|
||||||
|
assert.True(t, mp.stopped.True())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test start profiler with cpu overloaded", func(t *testing.T) {
|
||||||
|
mp := &mockProfiler{}
|
||||||
|
newProfiler = func(c Config) profiler {
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
c := Config{
|
||||||
|
Name: "test",
|
||||||
|
ServerAddr: "localhost:4040",
|
||||||
|
CheckInterval: time.Millisecond,
|
||||||
|
ProfilingDuration: time.Millisecond * 10,
|
||||||
|
CpuThreshold: 900,
|
||||||
|
}
|
||||||
|
var done = make(chan struct{})
|
||||||
|
go startPyroscope(c, done)
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
close(done)
|
||||||
|
|
||||||
|
assert.False(t, mp.started.True())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("start/stop err", func(t *testing.T) {
|
||||||
|
mp := &mockProfiler{
|
||||||
|
err: assert.AnError,
|
||||||
|
}
|
||||||
|
newProfiler = func(c Config) profiler {
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
c := Config{
|
||||||
|
Name: "test",
|
||||||
|
ServerAddr: "localhost:4040",
|
||||||
|
CheckInterval: time.Millisecond,
|
||||||
|
ProfilingDuration: time.Millisecond * 10,
|
||||||
|
CpuThreshold: 0,
|
||||||
|
}
|
||||||
|
var done = make(chan struct{})
|
||||||
|
go startPyroscope(c, done)
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
close(done)
|
||||||
|
|
||||||
|
assert.False(t, mp.started.True())
|
||||||
|
assert.False(t, mp.stopped.True())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenPyroscopeConf(t *testing.T) {
|
||||||
|
c := Config{
|
||||||
|
Name: "",
|
||||||
|
ServerAddr: "localhost:4040",
|
||||||
|
AuthUser: "user",
|
||||||
|
AuthPassword: "password",
|
||||||
|
ProfileType: ProfileType{
|
||||||
|
Logger: true,
|
||||||
|
CPU: true,
|
||||||
|
Goroutines: true,
|
||||||
|
Memory: true,
|
||||||
|
Mutex: true,
|
||||||
|
Block: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pyroscopeConf := genPyroscopeConf(c)
|
||||||
|
assert.Equal(t, c.ServerAddr, pyroscopeConf.ServerAddress)
|
||||||
|
assert.Equal(t, c.AuthUser, pyroscopeConf.BasicAuthUser)
|
||||||
|
assert.Equal(t, c.AuthPassword, pyroscopeConf.BasicAuthPassword)
|
||||||
|
assert.Equal(t, c.Name, pyroscopeConf.ApplicationName)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileCPU)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileGoroutines)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocObjects)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocSpace)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseObjects)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseSpace)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexCount)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexDuration)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockCount)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockDuration)
|
||||||
|
|
||||||
|
setFraction(c)
|
||||||
|
resetFraction(c)
|
||||||
|
|
||||||
|
newPyroscopeProfiler(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPyroscopeProfiler(t *testing.T) {
|
||||||
|
p := newPyroscopeProfiler(Config{})
|
||||||
|
|
||||||
|
assert.Error(t, p.Start())
|
||||||
|
assert.NoError(t, p.Stop())
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockProfiler struct {
|
||||||
|
mutex sync.Mutex
|
||||||
|
started syncx.AtomicBool
|
||||||
|
stopped syncx.AtomicBool
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProfiler) Start() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.err == nil {
|
||||||
|
m.started.Set(true)
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProfiler) Stop() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.err == nil {
|
||||||
|
m.stopped.Set(true)
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
43
mcp/config.go
Normal file
43
mcp/config.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/rest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpConf defines the configuration for an MCP server.
|
||||||
|
// It embeds rest.RestConf for HTTP server settings
|
||||||
|
// and adds MCP-specific configuration options.
|
||||||
|
type McpConf struct {
|
||||||
|
rest.RestConf
|
||||||
|
Mcp struct {
|
||||||
|
// Name is the server name reported in initialize responses
|
||||||
|
Name string `json:",optional"`
|
||||||
|
|
||||||
|
// Version is the server version reported in initialize responses
|
||||||
|
Version string `json:",default=1.0.0"`
|
||||||
|
|
||||||
|
// ProtocolVersion is the MCP protocol version implemented
|
||||||
|
ProtocolVersion string `json:",default=2024-11-05"`
|
||||||
|
|
||||||
|
// BaseUrl is the base URL for the server, used in SSE endpoint messages
|
||||||
|
// If not set, defaults to http://localhost:{Port}
|
||||||
|
BaseUrl string `json:",optional"`
|
||||||
|
|
||||||
|
// SseEndpoint is the path for Server-Sent Events connections
|
||||||
|
SseEndpoint string `json:",default=/sse"`
|
||||||
|
|
||||||
|
// MessageEndpoint is the path for JSON-RPC requests
|
||||||
|
MessageEndpoint string `json:",default=/message"`
|
||||||
|
|
||||||
|
// Cors contains allowed CORS origins
|
||||||
|
Cors []string `json:",optional"`
|
||||||
|
|
||||||
|
// SseTimeout is the maximum time allowed for SSE connections
|
||||||
|
SseTimeout time.Duration `json:",default=24h"`
|
||||||
|
|
||||||
|
// MessageTimeout is the maximum time allowed for request execution
|
||||||
|
MessageTimeout time.Duration `json:",default=30s"`
|
||||||
|
}
|
||||||
|
}
|
||||||
63
mcp/config_test.go
Normal file
63
mcp/config_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMcpConfDefaults(t *testing.T) {
|
||||||
|
// Test default values are set correctly when unmarshalled from JSON
|
||||||
|
jsonConfig := `name: test-service
|
||||||
|
port: 8080
|
||||||
|
mcp:
|
||||||
|
name: test-mcp-server
|
||||||
|
version: 1.0.0
|
||||||
|
`
|
||||||
|
|
||||||
|
var c McpConf
|
||||||
|
err := conf.LoadFromYamlBytes([]byte(jsonConfig), &c)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Check default values
|
||||||
|
assert.Equal(t, "test-mcp-server", c.Mcp.Name)
|
||||||
|
assert.Equal(t, "1.0.0", c.Mcp.Version, "Default version should be 1.0.0")
|
||||||
|
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
||||||
|
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
||||||
|
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
||||||
|
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcpConfCustomValues(t *testing.T) {
|
||||||
|
// Test custom values can be set
|
||||||
|
jsonConfig := `{
|
||||||
|
"Name": "test-service",
|
||||||
|
"Port": 8080,
|
||||||
|
"Mcp": {
|
||||||
|
"Name": "test-mcp-server",
|
||||||
|
"Version": "2.0.0",
|
||||||
|
"ProtocolVersion": "2025-01-01",
|
||||||
|
"BaseUrl": "http://example.com",
|
||||||
|
"SseEndpoint": "/custom-sse",
|
||||||
|
"MessageEndpoint": "/custom-message",
|
||||||
|
"Cors": ["http://localhost:3000", "http://example.com"],
|
||||||
|
"MessageTimeout": "60s"
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
var c McpConf
|
||||||
|
err := conf.LoadFromJsonBytes([]byte(jsonConfig), &c)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Check custom values
|
||||||
|
assert.Equal(t, "test-mcp-server", c.Mcp.Name, "Name should be inherited from RestConf")
|
||||||
|
assert.Equal(t, "2.0.0", c.Mcp.Version, "Version should be customizable")
|
||||||
|
assert.Equal(t, "2025-01-01", c.Mcp.ProtocolVersion, "Protocol version should be customizable")
|
||||||
|
assert.Equal(t, "http://example.com", c.Mcp.BaseUrl, "BaseUrl should be customizable")
|
||||||
|
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
||||||
|
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
||||||
|
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
||||||
|
assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable")
|
||||||
|
}
|
||||||
443
mcp/integration_test.go
Normal file
443
mcp/integration_test.go
Normal file
@@ -0,0 +1,443 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// syncResponseRecorder is a thread-safe wrapper around httptest.ResponseRecorder
|
||||||
|
type syncResponseRecorder struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new synchronized response recorder
|
||||||
|
func newSyncResponseRecorder() *syncResponseRecorder {
|
||||||
|
return &syncResponseRecorder{
|
||||||
|
ResponseRecorder: httptest.NewRecorder(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override Write method to synchronize access
|
||||||
|
func (srr *syncResponseRecorder) Write(p []byte) (int, error) {
|
||||||
|
srr.mu.Lock()
|
||||||
|
defer srr.mu.Unlock()
|
||||||
|
return srr.ResponseRecorder.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override WriteHeader method to synchronize access
|
||||||
|
func (srr *syncResponseRecorder) WriteHeader(statusCode int) {
|
||||||
|
srr.mu.Lock()
|
||||||
|
defer srr.mu.Unlock()
|
||||||
|
srr.ResponseRecorder.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override Result method to synchronize access
|
||||||
|
func (srr *syncResponseRecorder) Result() *http.Response {
|
||||||
|
srr.mu.Lock()
|
||||||
|
defer srr.mu.Unlock()
|
||||||
|
return srr.ResponseRecorder.Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTPHandlerIntegration tests the HTTP handlers with a real server instance
|
||||||
|
func TestHTTPHandlerIntegration(t *testing.T) {
|
||||||
|
// Skip in short test mode
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a test configuration
|
||||||
|
conf := McpConf{}
|
||||||
|
conf.Mcp.Name = "test-integration"
|
||||||
|
conf.Mcp.Version = "1.0.0-test"
|
||||||
|
conf.Mcp.MessageTimeout = 1 * time.Second
|
||||||
|
|
||||||
|
// Create a mock server directly
|
||||||
|
server := &sseMcpServer{
|
||||||
|
conf: conf,
|
||||||
|
clients: make(map[string]*mcpClient),
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register a test tool
|
||||||
|
err := server.RegisterTool(Tool{
|
||||||
|
Name: "echo",
|
||||||
|
Description: "Echo tool for testing",
|
||||||
|
InputSchema: InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Message to echo",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
if msg, ok := params["message"].(string); ok {
|
||||||
|
return fmt.Sprintf("Echo: %s", msg), nil
|
||||||
|
}
|
||||||
|
return "Echo: no message provided", nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a test HTTP request to the SSE endpoint
|
||||||
|
req := httptest.NewRequest("GET", "/sse", nil)
|
||||||
|
w := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Create a done channel to signal completion of test
|
||||||
|
done := make(chan bool)
|
||||||
|
|
||||||
|
// Start the SSE handler in a goroutine
|
||||||
|
go func() {
|
||||||
|
// lock.Lock()
|
||||||
|
server.handleSSE(w, req)
|
||||||
|
// lock.Unlock()
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Allow time for the handler to process
|
||||||
|
select {
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
// Expected - handler would normally block indefinitely
|
||||||
|
case <-done:
|
||||||
|
// This shouldn't happen immediately - the handler should block
|
||||||
|
t.Error("SSE handler returned unexpectedly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the initial headers
|
||||||
|
resp := w.Result()
|
||||||
|
assert.Equal(t, "chunked", resp.Header.Get("Transfer-Encoding"))
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
// The handler creates a client and sends the endpoint message
|
||||||
|
var sessionId string
|
||||||
|
|
||||||
|
// Give the handler time to set up the client
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Check that a client was created
|
||||||
|
server.clientsLock.Lock()
|
||||||
|
assert.Equal(t, 1, len(server.clients))
|
||||||
|
for id := range server.clients {
|
||||||
|
sessionId = id
|
||||||
|
}
|
||||||
|
server.clientsLock.Unlock()
|
||||||
|
|
||||||
|
require.NotEmpty(t, sessionId, "Expected a session ID to be created")
|
||||||
|
|
||||||
|
// Now that we have a session ID, we can test the message endpoint
|
||||||
|
messageBody, _ := json.Marshal(Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: methodInitialize,
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a message request
|
||||||
|
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, sessionId)
|
||||||
|
msgReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(messageBody))
|
||||||
|
msgW := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Process the message
|
||||||
|
server.handleRequest(msgW, msgReq)
|
||||||
|
|
||||||
|
// Check the response
|
||||||
|
msgResp := msgW.Result()
|
||||||
|
assert.Equal(t, http.StatusAccepted, msgResp.StatusCode)
|
||||||
|
msgResp.Body.Close() // Ensure response body is closed
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandlerResponseFlow tests the flow of a full request/response cycle
|
||||||
|
func TestHandlerResponseFlow(t *testing.T) {
|
||||||
|
// Create a mock server for testing
|
||||||
|
server := &sseMcpServer{
|
||||||
|
conf: McpConf{},
|
||||||
|
clients: map[string]*mcpClient{
|
||||||
|
"test-session": {
|
||||||
|
id: "test-session",
|
||||||
|
channel: make(chan string, 10),
|
||||||
|
initialized: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register test resources
|
||||||
|
server.RegisterTool(Tool{
|
||||||
|
Name: "test.tool",
|
||||||
|
Description: "Test tool",
|
||||||
|
InputSchema: InputSchema{Type: "object"},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
return "tool result", nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
server.RegisterPrompt(Prompt{
|
||||||
|
Name: "test.prompt",
|
||||||
|
Description: "Test prompt",
|
||||||
|
})
|
||||||
|
|
||||||
|
server.RegisterResource(Resource{
|
||||||
|
Name: "test.resource",
|
||||||
|
URI: "http://example.com",
|
||||||
|
Description: "Test resource",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a request with session ID parameter
|
||||||
|
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, "test-session")
|
||||||
|
|
||||||
|
// Test tools/list request
|
||||||
|
toolsListBody, _ := json.Marshal(Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: methodToolsList,
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
toolsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(toolsListBody))
|
||||||
|
toolsW := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
server.handleRequest(toolsW, toolsReq)
|
||||||
|
|
||||||
|
// Check the response code
|
||||||
|
toolsResp := toolsW.Result()
|
||||||
|
assert.Equal(t, http.StatusAccepted, toolsResp.StatusCode)
|
||||||
|
toolsResp.Body.Close()
|
||||||
|
|
||||||
|
// Check the channel message
|
||||||
|
client := server.clients["test-session"]
|
||||||
|
select {
|
||||||
|
case message := <-client.channel:
|
||||||
|
assert.Contains(t, message, `"tools":[{"name":"test.tool"`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for tools/list response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test prompts/list request
|
||||||
|
promptsListBody, _ := json.Marshal(Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 2,
|
||||||
|
Method: methodPromptsList,
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
promptsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(promptsListBody))
|
||||||
|
promptsW := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
server.handleRequest(promptsW, promptsReq)
|
||||||
|
|
||||||
|
// Check the response code
|
||||||
|
promptsResp := promptsW.Result()
|
||||||
|
assert.Equal(t, http.StatusAccepted, promptsResp.StatusCode)
|
||||||
|
promptsResp.Body.Close()
|
||||||
|
|
||||||
|
// Check the channel message
|
||||||
|
select {
|
||||||
|
case message := <-client.channel:
|
||||||
|
assert.Contains(t, message, `"prompts":[{"name":"test.prompt"`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for prompts/list response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test resources/list request
|
||||||
|
resourcesListBody, _ := json.Marshal(Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 3,
|
||||||
|
Method: methodResourcesList,
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
resourcesReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(resourcesListBody))
|
||||||
|
resourcesW := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
server.handleRequest(resourcesW, resourcesReq)
|
||||||
|
|
||||||
|
// Check the response code
|
||||||
|
resourcesResp := resourcesW.Result()
|
||||||
|
assert.Equal(t, http.StatusAccepted, resourcesResp.StatusCode)
|
||||||
|
resourcesResp.Body.Close()
|
||||||
|
|
||||||
|
// Check the channel message
|
||||||
|
select {
|
||||||
|
case message := <-client.channel:
|
||||||
|
assert.Contains(t, message, `"name":"test.resource"`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for resources/list response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessListMethods tests the list processing methods with pagination
|
||||||
|
func TestProcessListMethods(t *testing.T) {
|
||||||
|
server := &sseMcpServer{
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some test data
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
tool := Tool{
|
||||||
|
Name: fmt.Sprintf("tool%d", i),
|
||||||
|
Description: fmt.Sprintf("Tool %d", i),
|
||||||
|
InputSchema: InputSchema{Type: "object"},
|
||||||
|
}
|
||||||
|
server.tools[tool.Name] = tool
|
||||||
|
|
||||||
|
prompt := Prompt{
|
||||||
|
Name: fmt.Sprintf("prompt%d", i),
|
||||||
|
Description: fmt.Sprintf("Prompt %d", i),
|
||||||
|
}
|
||||||
|
server.prompts[prompt.Name] = prompt
|
||||||
|
|
||||||
|
resource := Resource{
|
||||||
|
Name: fmt.Sprintf("resource%d", i),
|
||||||
|
URI: fmt.Sprintf("http://example.com/%d", i),
|
||||||
|
Description: fmt.Sprintf("Resource %d", i),
|
||||||
|
}
|
||||||
|
server.resources[resource.Name] = resource
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := &mcpClient{
|
||||||
|
id: "test-client",
|
||||||
|
channel: make(chan string, 10),
|
||||||
|
initialized: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test processListTools
|
||||||
|
req := Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: methodToolsList,
|
||||||
|
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
server.processListTools(context.Background(), client, req)
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"tools":`)
|
||||||
|
assert.Contains(t, response, `"progressToken":"token1"`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for tools/list response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test processListPrompts
|
||||||
|
req.ID = 2
|
||||||
|
req.Method = methodPromptsList
|
||||||
|
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||||
|
server.processListPrompts(context.Background(), client, req)
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"prompts":`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for prompts/list response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test processListResources
|
||||||
|
req.ID = 3
|
||||||
|
req.Method = methodResourcesList
|
||||||
|
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||||
|
server.processListResources(context.Background(), client, req)
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"resources":`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for resources/list response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestErrorResponseHandling tests error handling in the server
|
||||||
|
func TestErrorResponseHandling(t *testing.T) {
|
||||||
|
server := &sseMcpServer{
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := &mcpClient{
|
||||||
|
id: "test-client",
|
||||||
|
channel: make(chan string, 10),
|
||||||
|
initialized: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid method
|
||||||
|
req := Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: "invalid_method",
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock handleRequest by directly calling error handler
|
||||||
|
server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"error":{"code":-32601,"message":"Method not found"}`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for error response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid tool
|
||||||
|
toolReq := Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 2,
|
||||||
|
Method: methodToolsCall,
|
||||||
|
Params: json.RawMessage(`{"name":"non_existent_tool"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call process method directly
|
||||||
|
server.processToolCall(context.Background(), client, toolReq)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"error":{"code":-32602,"message":"Tool 'non_existent_tool' not found"}`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for error response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid prompt
|
||||||
|
promptReq := Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 3,
|
||||||
|
Method: methodPromptsGet,
|
||||||
|
Params: json.RawMessage(`{"name":"non_existent_prompt"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call process method directly
|
||||||
|
server.processGetPrompt(context.Background(), client, promptReq)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"error":{"code":-32602,"message":"Prompt 'non_existent_prompt' not found"}`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for error response")
|
||||||
|
}
|
||||||
|
}
|
||||||
23
mcp/parser.go
Normal file
23
mcp/parser.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/mapping"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseArguments parses the arguments and populates the request object
|
||||||
|
func ParseArguments(args any, req any) error {
|
||||||
|
switch arguments := args.(type) {
|
||||||
|
case map[string]string:
|
||||||
|
m := make(map[string]any, len(arguments))
|
||||||
|
for k, v := range arguments {
|
||||||
|
m[k] = v
|
||||||
|
}
|
||||||
|
return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues())
|
||||||
|
case map[string]any:
|
||||||
|
return mapping.UnmarshalJsonMap(arguments, req)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported argument type: %T", arguments)
|
||||||
|
}
|
||||||
|
}
|
||||||
139
mcp/parser_test.go
Normal file
139
mcp/parser_test.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestParseArguments_MapStringString tests parsing map[string]string arguments
|
||||||
|
func TestParseArguments_MapStringString(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test arguments
|
||||||
|
args := map[string]string{
|
||||||
|
"name": "test-name",
|
||||||
|
"message": "hello world",
|
||||||
|
"count": "42",
|
||||||
|
"enabled": "true",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a target object to populate
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
// Verify results
|
||||||
|
assert.NoError(t, err, "Should parse map[string]string without error")
|
||||||
|
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||||
|
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||||
|
assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int")
|
||||||
|
assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseArguments_MapStringAny tests parsing map[string]any arguments
|
||||||
|
func TestParseArguments_MapStringAny(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
Metadata map[string]string `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test arguments with mixed types
|
||||||
|
args := map[string]any{
|
||||||
|
"name": "test-name",
|
||||||
|
"message": "hello world",
|
||||||
|
"count": 42, // note: this is already an int
|
||||||
|
"enabled": true, // note: this is already a bool
|
||||||
|
"tags": []string{"tag1", "tag2"},
|
||||||
|
"metadata": map[string]string{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a target object to populate
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
// Verify results
|
||||||
|
assert.NoError(t, err, "Should parse map[string]any without error")
|
||||||
|
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||||
|
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||||
|
assert.Equal(t, 42, req.Count, "Count should be correctly parsed")
|
||||||
|
assert.True(t, req.Enabled, "Enabled should be correctly parsed")
|
||||||
|
assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed")
|
||||||
|
assert.Equal(t, map[string]string{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
}, req.Metadata, "Metadata should be correctly parsed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseArguments_UnsupportedType tests parsing with an unsupported type
|
||||||
|
func TestParseArguments_UnsupportedType(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use an unsupported argument type (slice)
|
||||||
|
args := []string{"not", "a", "map"}
|
||||||
|
|
||||||
|
// Create a target object to populate
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
// Verify error is returned with correct message
|
||||||
|
assert.Error(t, err, "Should return error for unsupported type")
|
||||||
|
assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type")
|
||||||
|
assert.Contains(t, err.Error(), "[]string", "Error should include the actual type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseArguments_EmptyMap tests parsing with empty maps
|
||||||
|
func TestParseArguments_EmptyMap(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name,optional"`
|
||||||
|
Message string `json:"message,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test empty map[string]string
|
||||||
|
t.Run("EmptyMapStringString", func(t *testing.T) {
|
||||||
|
args := map[string]string{}
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
assert.NoError(t, err, "Should parse empty map[string]string without error")
|
||||||
|
assert.Empty(t, req.Name, "Name should be empty string")
|
||||||
|
assert.Empty(t, req.Message, "Message should be empty string")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test empty map[string]any
|
||||||
|
t.Run("EmptyMapStringAny", func(t *testing.T) {
|
||||||
|
args := map[string]any{}
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
assert.NoError(t, err, "Should parse empty map[string]any without error")
|
||||||
|
assert.Empty(t, req.Name, "Name should be empty string")
|
||||||
|
assert.Empty(t, req.Message, "Message should be empty string")
|
||||||
|
})
|
||||||
|
}
|
||||||
870
mcp/readme.md
Normal file
870
mcp/readme.md
Normal file
@@ -0,0 +1,870 @@
|
|||||||
|
# Model Context Protocol (MCP) Implementation
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
This package implements the Model Context Protocol (MCP) server specification in Go, providing a framework for real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation follows the standardized protocol for building AI-assisted applications with bidirectional communication capabilities.
|
||||||
|
|
||||||
|
## Core Components
|
||||||
|
|
||||||
|
### Server-Sent Events (SSE) Communication
|
||||||
|
- **Real-time Communication**: Robust SSE-based communication system that maintains persistent connections with clients
|
||||||
|
- **Connection Management**: Client registration, message broadcasting, and client cleanup mechanisms
|
||||||
|
- **Event Handling**: Event types for tools, prompts, and resources changes
|
||||||
|
|
||||||
|
### JSON-RPC Implementation
|
||||||
|
- **Request Processing**: Complete JSON-RPC request processor for handling MCP protocol methods
|
||||||
|
- **Response Formatting**: Proper response formatting according to JSON-RPC specifications
|
||||||
|
- **Error Handling**: Comprehensive error handling with appropriate error codes
|
||||||
|
|
||||||
|
### Tool Management
|
||||||
|
- **Tool Registration**: System to register custom tools with handlers
|
||||||
|
- **Tool Execution**: Mechanism to execute tool functions with proper timeout handling
|
||||||
|
- **Result Handling**: Flexible result handling supporting various return types (string, JSON, images)
|
||||||
|
|
||||||
|
### Prompt System
|
||||||
|
- **Prompt Registration**: System for registering both static and dynamic prompts
|
||||||
|
- **Argument Validation**: Validation for required arguments and default values for optional ones
|
||||||
|
- **Message Generation**: Handlers that generate properly formatted conversation messages
|
||||||
|
|
||||||
|
### Resource Management
|
||||||
|
- **Resource Registration**: System for managing and accessing external resources
|
||||||
|
- **Content Delivery**: Handlers for delivering resource content to clients on demand
|
||||||
|
- **Resource Subscription**: Mechanisms for clients to subscribe to resource updates
|
||||||
|
|
||||||
|
### Protocol Features
|
||||||
|
- **Initialization Sequence**: Proper handshaking with capability negotiation
|
||||||
|
- **Notification Handling**: Support for both standard and client-specific notifications
|
||||||
|
- **Message Routing**: Intelligent routing of requests to appropriate handlers
|
||||||
|
|
||||||
|
## Technical Highlights
|
||||||
|
|
||||||
|
### Configuration System
|
||||||
|
- **Flexible Configuration**: Configuration system with sensible defaults and customization options
|
||||||
|
- **CORS Support**: Configurable CORS settings for cross-origin requests
|
||||||
|
- **Server Information**: Proper server identification and versioning
|
||||||
|
|
||||||
|
### Client Session Management
|
||||||
|
- **Session Tracking**: Client session tracking with unique identifiers
|
||||||
|
- **Connection Health**: Ping/pong mechanism to maintain connection health
|
||||||
|
- **Initialization State**: Client initialization state tracking
|
||||||
|
|
||||||
|
### Content Handling
|
||||||
|
- **Multi-format Content**: Support for text, code, and binary content
|
||||||
|
- **MIME Type Support**: Proper MIME type identification for various content types
|
||||||
|
- **Audience Annotations**: Content audience annotations for user/assistant targeting
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Setting Up an MCP Server
|
||||||
|
|
||||||
|
To create and start an MCP server:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load configuration from YAML file
|
||||||
|
var c mcp.McpConf
|
||||||
|
conf.MustLoad("config.yaml", &c)
|
||||||
|
|
||||||
|
// Optional: Disable stats logging
|
||||||
|
logx.DisableStat()
|
||||||
|
|
||||||
|
// Create MCP server
|
||||||
|
server := mcp.NewMcpServer(c)
|
||||||
|
|
||||||
|
// Register tools, prompts, and resources (examples below)
|
||||||
|
|
||||||
|
// Start the server and ensure it's stopped on exit
|
||||||
|
defer server.Stop()
|
||||||
|
server.Start()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Sample configuration file (config.yaml):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
name: mcp-server
|
||||||
|
host: localhost
|
||||||
|
port: 8080
|
||||||
|
mcp:
|
||||||
|
name: my-mcp-server
|
||||||
|
messageTimeout: 30s # Timeout for tool calls
|
||||||
|
cors:
|
||||||
|
- http://localhost:3000 # Optional CORS configuration
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering Tools
|
||||||
|
|
||||||
|
Tools allow AI models to execute custom code through the MCP protocol.
|
||||||
|
|
||||||
|
#### Basic Tool Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a simple echo tool
|
||||||
|
echoTool := mcp.Tool{
|
||||||
|
Name: "echo",
|
||||||
|
Description: "Echoes back the message provided by the user",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The message to echo back",
|
||||||
|
},
|
||||||
|
"prefix": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional prefix to add to the echoed message",
|
||||||
|
"default": "Echo: ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"message"},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Prefix string `json:"prefix,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := "Echo: "
|
||||||
|
if len(req.Prefix) > 0 {
|
||||||
|
prefix = req.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefix + req.Message, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(echoTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Tool with Different Response Types:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Tool returning JSON data
|
||||||
|
dataTool := mcp.Tool{
|
||||||
|
Name: "data.generate",
|
||||||
|
Description: "Generates sample data in various formats",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"format": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Format of data (json, text)",
|
||||||
|
"enum": []string{"json", "text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
Format string `json:"format"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Format == "json" {
|
||||||
|
// Return structured data
|
||||||
|
return map[string]any{
|
||||||
|
"items": []map[string]any{
|
||||||
|
{"id": 1, "name": "Item 1"},
|
||||||
|
{"id": 2, "name": "Item 2"},
|
||||||
|
},
|
||||||
|
"count": 2,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to text
|
||||||
|
return "Sample text data", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(dataTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Image Generation Tool Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Tool returning image content
|
||||||
|
imageTool := mcp.Tool{
|
||||||
|
Name: "image.generate",
|
||||||
|
Description: "Generates a simple image",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"type": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Type of image to generate",
|
||||||
|
"default": "placeholder",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
// Return image content directly
|
||||||
|
return mcp.ImageContent{
|
||||||
|
Data: "base64EncodedImageData...", // Base64 encoded image data
|
||||||
|
MimeType: "image/png",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(imageTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Using ToolResult for Custom Outputs:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Tool that returns a custom ToolResult type
|
||||||
|
customResultTool := mcp.Tool{
|
||||||
|
Name: "custom.result",
|
||||||
|
Description: "Returns a custom formatted result",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"resultType": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"enum": []string{"text", "image"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
ResultType string `json:"resultType"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ResultType == "image" {
|
||||||
|
return mcp.ToolResult{
|
||||||
|
Type: mcp.ContentTypeImage,
|
||||||
|
Content: map[string]any{
|
||||||
|
"data": "base64EncodedImageData...",
|
||||||
|
"mimeType": "image/jpeg",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to text
|
||||||
|
return mcp.ToolResult{
|
||||||
|
Type: mcp.ContentTypeText,
|
||||||
|
Content: "This is a text result from ToolResult",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(customResultTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering Prompts
|
||||||
|
|
||||||
|
Prompts are reusable conversation templates for AI models.
|
||||||
|
|
||||||
|
#### Static Prompt Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a simple static prompt with placeholders
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "hello",
|
||||||
|
Description: "A simple hello prompt",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "name",
|
||||||
|
Description: "The name to greet",
|
||||||
|
Required: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Content: "Say hello to {{name}} and introduce yourself as an AI assistant.",
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Dynamic Prompt with Handler Function:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt with a dynamic handler function
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "dynamic-prompt",
|
||||||
|
Description: "A prompt that uses a handler to generate dynamic content",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "username",
|
||||||
|
Description: "User's name for personalized greeting",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "topic",
|
||||||
|
Description: "Topic of expertise",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a user message
|
||||||
|
userMessage := mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an assistant response with current time
|
||||||
|
currentTime := time.Now().Format(time.RFC1123)
|
||||||
|
assistantMessage := mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||||
|
req.Username, req.Topic, currentTime),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return both messages as a conversation
|
||||||
|
return []mcp.PromptMessage{userMessage, assistantMessage}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Multi-Message Prompt with Code Examples:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt that provides code examples in different programming languages
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "code-example",
|
||||||
|
Description: "Provides code examples in different programming languages",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "language",
|
||||||
|
Description: "Programming language for the example",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "complexity",
|
||||||
|
Description: "Complexity level (simple, medium, advanced)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Language string `json:"language"`
|
||||||
|
Complexity string `json:"complexity,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate language
|
||||||
|
supportedLanguages := map[string]bool{"go": true, "python": true, "javascript": true, "rust": true}
|
||||||
|
if !supportedLanguages[req.Language] {
|
||||||
|
return nil, fmt.Errorf("unsupported language: %s", req.Language)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate code example based on language and complexity
|
||||||
|
var codeExample string
|
||||||
|
|
||||||
|
switch req.Language {
|
||||||
|
case "go":
|
||||||
|
if req.Complexity == "simple" {
|
||||||
|
codeExample = `
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println("Hello, World!")
|
||||||
|
}`
|
||||||
|
} else {
|
||||||
|
codeExample = `
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
now := time.Now()
|
||||||
|
fmt.Printf("Hello, World! Current time is %s\n", now.Format(time.RFC3339))
|
||||||
|
}`
|
||||||
|
}
|
||||||
|
case "python":
|
||||||
|
// Python example code
|
||||||
|
if req.Complexity == "simple" {
|
||||||
|
codeExample = `
|
||||||
|
def greet(name):
|
||||||
|
return f"Hello, {name}!"
|
||||||
|
|
||||||
|
print(greet("World"))`
|
||||||
|
} else {
|
||||||
|
codeExample = `
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
def greet(name, include_time=False):
|
||||||
|
message = f"Hello, {name}!"
|
||||||
|
if include_time:
|
||||||
|
message += f" Current time is {datetime.datetime.now().isoformat()}"
|
||||||
|
return message
|
||||||
|
|
||||||
|
print(greet("World", include_time=True))`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create messages array according to MCP spec
|
||||||
|
messages := []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("You are a helpful coding assistant specialized in %s programming.", req.Language),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Show me a %s example of a Hello World program in %s.", req.Complexity, req.Language),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Here's a %s example in %s:\n\n```%s%s\n```\n\nHow can I help you implement this?",
|
||||||
|
req.Complexity, req.Language, req.Language, codeExample),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering Resources
|
||||||
|
|
||||||
|
Resources provide access to external content such as files or generated data.
|
||||||
|
|
||||||
|
#### Basic Resource Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a static resource
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "example-document",
|
||||||
|
URI: "file:///example/document.txt",
|
||||||
|
Description: "An example document",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///example/document.txt",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Text: "This is an example document content.",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Dynamic Resource with Code Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a Go code resource with dynamic handler
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "go-example",
|
||||||
|
URI: "file:///project/src/main.go",
|
||||||
|
Description: "A simple Go example with multiple files",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
// Return ResourceContent with all required fields
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///project/src/main.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a companion file for the above example
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "go-greeting",
|
||||||
|
URI: "file:///project/src/greeting/greeting.go",
|
||||||
|
Description: "A greeting package for the Go example",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///project/src/greeting/greeting.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Binary Resource Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a binary resource (like an image)
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "example-image",
|
||||||
|
URI: "file:///example/image.png",
|
||||||
|
Description: "An example image",
|
||||||
|
MimeType: "image/png",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
// Read image from file or generate it
|
||||||
|
imageData := "base64EncodedImageData..." // Base64 encoded image data
|
||||||
|
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///example/image.png",
|
||||||
|
MimeType: "image/png",
|
||||||
|
Blob: imageData, // For binary data
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using Resources in Prompts
|
||||||
|
|
||||||
|
You can embed resources in prompt responses to create rich interactions with proper MCP-compliant structure:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt that embeds a resource
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "resource-example",
|
||||||
|
Description: "A prompt that embeds a resource",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "file_type",
|
||||||
|
Description: "Type of file to show (rust or go)",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
FileType string `json:"file_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resourceURI, mimeType, fileContent string
|
||||||
|
if req.FileType == "rust" {
|
||||||
|
resourceURI = "file:///project/src/main.rs"
|
||||||
|
mimeType = "text/x-rust"
|
||||||
|
fileContent = "fn main() {\n println!(\"Hello world!\");\n}"
|
||||||
|
} else {
|
||||||
|
resourceURI = "file:///project/src/main.go"
|
||||||
|
mimeType = "text/x-go"
|
||||||
|
fileContent = "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello, world!\")\n}"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create message with embedded resource using proper MCP format
|
||||||
|
return []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Can you explain this %s code?", req.FileType),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.EmbeddedResource{
|
||||||
|
Type: mcp.ContentTypeResource,
|
||||||
|
Resource: struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Blob string `json:"blob,omitempty"`
|
||||||
|
}{
|
||||||
|
URI: resourceURI,
|
||||||
|
MimeType: mimeType,
|
||||||
|
Text: fileContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Above is a simple Hello World example in %s. Let me explain how it works.", req.FileType),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple File Resources Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt that demonstrates embedding multiple resource files
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "go-code-example",
|
||||||
|
Description: "A prompt that correctly embeds multiple resource files",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "format",
|
||||||
|
Description: "How to format the code display",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Format string `json:"format,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the Go code for multiple files
|
||||||
|
var mainGoText string = "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}"
|
||||||
|
var greetingGoText string = "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}"
|
||||||
|
|
||||||
|
// Create message with properly formatted embedded resource per MCP spec
|
||||||
|
messages := []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: "Show me a simple Go example with proper imports.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: "Here's a simple Go example project:",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.EmbeddedResource{
|
||||||
|
Type: mcp.ContentTypeResource,
|
||||||
|
Resource: struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Blob string `json:"blob,omitempty"`
|
||||||
|
}{
|
||||||
|
URI: "file:///project/src/main.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: mainGoText,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add explanation and additional file if requested
|
||||||
|
if req.Format == "with_explanation" {
|
||||||
|
messages = append(messages, mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: "This example demonstrates a simple Go application with modular structure. The main.go file imports from a local 'greeting' package that provides the Hello function.",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Also show the greeting.go file with correct resource format
|
||||||
|
messages = append(messages, mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.EmbeddedResource{
|
||||||
|
Type: mcp.ContentTypeResource,
|
||||||
|
Resource: struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Blob string `json:"blob,omitempty"`
|
||||||
|
}{
|
||||||
|
URI: "file:///project/src/greeting/greeting.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: greetingGoText,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complete Application Example
|
||||||
|
|
||||||
|
Here's a complete example demonstrating all the components:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load configuration
|
||||||
|
var c mcp.McpConf
|
||||||
|
if err := conf.Load("config.yaml", &c); err != nil {
|
||||||
|
log.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up logging
|
||||||
|
logx.DisableStat()
|
||||||
|
|
||||||
|
// Create MCP server
|
||||||
|
server := mcp.NewMcpServer(c)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
// Register a simple echo tool
|
||||||
|
echoTool := mcp.Tool{
|
||||||
|
Name: "echo",
|
||||||
|
Description: "Echoes back the message provided by the user",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The message to echo back",
|
||||||
|
},
|
||||||
|
"prefix": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional prefix to add to the echoed message",
|
||||||
|
"default": "Echo: ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"message"},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Prefix string `json:"prefix,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := "Echo: "
|
||||||
|
if len(req.Prefix) > 0 {
|
||||||
|
prefix = req.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefix + req.Message, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
server.RegisterTool(echoTool)
|
||||||
|
|
||||||
|
// Register a static prompt
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "greeting",
|
||||||
|
Description: "A simple greeting prompt",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "name",
|
||||||
|
Description: "The name to greet",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Content: "Hello {{name}}! How can I assist you today?",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a dynamic prompt
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "dynamic-prompt",
|
||||||
|
Description: "A prompt that uses a handler to generate dynamic content",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "username",
|
||||||
|
Description: "User's name for personalized greeting",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "topic",
|
||||||
|
Description: "Topic of expertise",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create messages with current time
|
||||||
|
currentTime := time.Now().Format(time.RFC1123)
|
||||||
|
return []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||||
|
req.Username, req.Topic, currentTime),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a resource
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "example-doc",
|
||||||
|
URI: "file:///example/doc.txt",
|
||||||
|
Description: "An example document",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///example/doc.txt",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Text: "This is the content of the example document.",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
fmt.Printf("Starting MCP server on %s:%d\n", c.Host, c.Port)
|
||||||
|
server.Start()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
The MCP implementation provides comprehensive error handling:
|
||||||
|
|
||||||
|
- Tool execution errors are properly reported back to clients
|
||||||
|
- Missing or invalid parameters are detected and reported with appropriate error codes
|
||||||
|
- Resource and prompt lookup failures are handled gracefully
|
||||||
|
- Timeout handling for long-running tool executions using context
|
||||||
|
- Panic recovery to prevent server crashes
|
||||||
|
|
||||||
|
## Advanced Features
|
||||||
|
|
||||||
|
- **Annotations**: Add audience and priority metadata to content
|
||||||
|
- **Content Types**: Support for text, images, audio, and other content formats
|
||||||
|
- **Embedded Resources**: Include file resources directly in prompt responses
|
||||||
|
- **Context Awareness**: All handlers receive context.Context for timeout and cancellation support
|
||||||
|
- **Progress Tokens**: Support for tracking progress of long-running operations
|
||||||
|
- **Customizable Timeouts**: Configure execution timeouts for tools and operations
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
- Tool execution runs with configurable timeouts to prevent blocking
|
||||||
|
- Efficient client tracking and cleanup to prevent resource leaks
|
||||||
|
- Proper concurrency handling with mutex protection for shared resources
|
||||||
|
- Buffered message channels to prevent blocking on client message delivery
|
||||||
940
mcp/server.go
Normal file
940
mcp/server.go
Normal file
@@ -0,0 +1,940 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/rest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewMcpServer(c McpConf) McpServer {
|
||||||
|
var server *rest.Server
|
||||||
|
if len(c.Mcp.Cors) == 0 {
|
||||||
|
server = rest.MustNewServer(c.RestConf)
|
||||||
|
} else {
|
||||||
|
server = rest.MustNewServer(c.RestConf, rest.WithCors(c.Mcp.Cors...))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.Mcp.Name) == 0 {
|
||||||
|
c.Mcp.Name = c.Name
|
||||||
|
}
|
||||||
|
if len(c.Mcp.BaseUrl) == 0 {
|
||||||
|
c.Mcp.BaseUrl = fmt.Sprintf("http://localhost:%d", c.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &sseMcpServer{
|
||||||
|
conf: c,
|
||||||
|
server: server,
|
||||||
|
clients: make(map[string]*mcpClient),
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE endpoint for real-time updates
|
||||||
|
s.server.AddRoute(rest.Route{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: s.conf.Mcp.SseEndpoint,
|
||||||
|
Handler: s.handleSSE,
|
||||||
|
}, rest.WithSSE(), rest.WithTimeout(c.Mcp.SseTimeout))
|
||||||
|
|
||||||
|
// JSON-RPC message endpoint for regular requests
|
||||||
|
s.server.AddRoute(rest.Route{
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: s.conf.Mcp.MessageEndpoint,
|
||||||
|
Handler: s.handleRequest,
|
||||||
|
}, rest.WithTimeout(c.Mcp.MessageTimeout))
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterPrompt registers a new prompt with the server
|
||||||
|
func (s *sseMcpServer) RegisterPrompt(prompt Prompt) {
|
||||||
|
s.promptsLock.Lock()
|
||||||
|
s.prompts[prompt.Name] = prompt
|
||||||
|
s.promptsLock.Unlock()
|
||||||
|
// Notify clients about the new prompt
|
||||||
|
s.broadcast(eventPromptsListChanged, map[string][]Prompt{keyPrompts: {prompt}})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterResource registers a new resource with the server
|
||||||
|
func (s *sseMcpServer) RegisterResource(resource Resource) {
|
||||||
|
s.resourcesLock.Lock()
|
||||||
|
s.resources[resource.URI] = resource
|
||||||
|
s.resourcesLock.Unlock()
|
||||||
|
// Notify clients about the new resource
|
||||||
|
s.broadcast(eventResourcesListChanged, map[string][]Resource{keyResources: {resource}})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterTool registers a new tool with the server
|
||||||
|
func (s *sseMcpServer) RegisterTool(tool Tool) error {
|
||||||
|
if tool.Handler == nil {
|
||||||
|
return fmt.Errorf("tool '%s' has no handler function", tool.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.toolsLock.Lock()
|
||||||
|
s.tools[tool.Name] = tool
|
||||||
|
s.toolsLock.Unlock()
|
||||||
|
// Notify clients about the new tool
|
||||||
|
s.broadcast(eventToolsListChanged, map[string][]Tool{keyTools: {tool}})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start implements McpServer.
|
||||||
|
func (s *sseMcpServer) Start() {
|
||||||
|
s.server.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sseMcpServer) Stop() {
|
||||||
|
s.server.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// broadcast sends a message to all connected clients
|
||||||
|
// It uses Server-Sent Events (SSE) format for real-time communication
|
||||||
|
func (s *sseMcpServer) broadcast(event string, data any) {
|
||||||
|
jsonData, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
logx.Errorf("Failed to marshal broadcast data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lock only while reading the clients map
|
||||||
|
s.clientsLock.Lock()
|
||||||
|
clients := make([]*mcpClient, 0, len(s.clients))
|
||||||
|
for _, client := range s.clients {
|
||||||
|
clients = append(clients, client)
|
||||||
|
}
|
||||||
|
s.clientsLock.Unlock()
|
||||||
|
|
||||||
|
clientCount := len(clients)
|
||||||
|
if clientCount == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.Infof("Broadcasting event '%s' to %d clients", event, clientCount)
|
||||||
|
|
||||||
|
// Use CRLF line endings as per SSE specification
|
||||||
|
message := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(jsonData))
|
||||||
|
|
||||||
|
// Send messages without holding the lock
|
||||||
|
for _, client := range clients {
|
||||||
|
select {
|
||||||
|
case client.channel <- message:
|
||||||
|
// Message sent successfully
|
||||||
|
default:
|
||||||
|
// Channel buffer is full, log warning and continue
|
||||||
|
logx.Errorf("Client channel buffer full, dropping message for client %s", client.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupClient removes a client from the active clients map
|
||||||
|
func (s *sseMcpServer) cleanupClient(sessionId string) {
|
||||||
|
s.clientsLock.Lock()
|
||||||
|
defer s.clientsLock.Unlock()
|
||||||
|
|
||||||
|
if client, exists := s.clients[sessionId]; exists {
|
||||||
|
// Close the channel to signal any goroutines waiting on it
|
||||||
|
close(client.channel)
|
||||||
|
// Remove from active clients
|
||||||
|
delete(s.clients, sessionId)
|
||||||
|
logx.Infof("Cleaned up client %s (remaining clients: %d)", sessionId, len(s.clients))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRequest handles MCP JSON-RPC requests
|
||||||
|
func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Extract sessionId from query parameters
|
||||||
|
sessionId := r.URL.Query().Get(sessionIdKey)
|
||||||
|
if len(sessionId) == 0 {
|
||||||
|
http.Error(w, fmt.Sprintf("Missing %s", sessionIdKey), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the client with this sessionId exists
|
||||||
|
s.clientsLock.Lock()
|
||||||
|
client, exists := s.clients[sessionId]
|
||||||
|
s.clientsLock.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
http.Error(w, fmt.Sprintf("Invalid or expired %s", sessionIdKey), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req Request
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For notification methods (no ID), we don't send a response
|
||||||
|
isNotification, err := req.isNotification()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid request.ID", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
|
||||||
|
// Special handling for initialization sequence
|
||||||
|
// Always allow initialize and notifications/initialized regardless of client state
|
||||||
|
if req.Method == methodInitialize {
|
||||||
|
logx.Infof("Processing initialize request with ID: %v", req.ID)
|
||||||
|
s.processInitialize(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID)
|
||||||
|
return
|
||||||
|
} else if req.Method == methodNotificationsInitialized {
|
||||||
|
// Handle initialized notification
|
||||||
|
logx.Info("Received notifications/initialized notification")
|
||||||
|
if !isNotification {
|
||||||
|
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||||
|
"Method should be used as a notification", errCodeInvalidRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.processNotificationInitialized(client)
|
||||||
|
return
|
||||||
|
} else if !client.initialized && req.Method != methodNotificationsCancelled {
|
||||||
|
// Block most requests until client is initialized (except for cancellations)
|
||||||
|
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||||
|
"Client not fully initialized, waiting for notifications/initialized",
|
||||||
|
errCodeClientNotInitialized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process normal requests only after initialization
|
||||||
|
switch req.Method {
|
||||||
|
case methodToolsCall:
|
||||||
|
logx.Infof("Received tools call request with ID: %v", req.ID)
|
||||||
|
s.processToolCall(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent tools call response for ID: %v", req.ID)
|
||||||
|
case methodToolsList:
|
||||||
|
logx.Infof("Processing tools/list request with ID: %v", req.ID)
|
||||||
|
s.processListTools(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent tools/list response for ID: %v", req.ID)
|
||||||
|
case methodPromptsList:
|
||||||
|
logx.Infof("Processing prompts/list request with ID: %v", req.ID)
|
||||||
|
s.processListPrompts(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent prompts/list response for ID: %v", req.ID)
|
||||||
|
case methodPromptsGet:
|
||||||
|
logx.Infof("Processing prompts/get request with ID: %v", req.ID)
|
||||||
|
s.processGetPrompt(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent prompts/get response for ID: %v", req.ID)
|
||||||
|
case methodResourcesList:
|
||||||
|
logx.Infof("Processing resources/list request with ID: %v", req.ID)
|
||||||
|
s.processListResources(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent resources/list response for ID: %v", req.ID)
|
||||||
|
case methodResourcesRead:
|
||||||
|
logx.Infof("Processing resources/read request with ID: %v", req.ID)
|
||||||
|
s.processResourcesRead(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent resources/read response for ID: %v", req.ID)
|
||||||
|
case methodResourcesSubscribe:
|
||||||
|
logx.Infof("Processing resources/subscribe request with ID: %v", req.ID)
|
||||||
|
s.processResourceSubscribe(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent resources/subscribe response for ID: %v", req.ID)
|
||||||
|
case methodPing:
|
||||||
|
logx.Infof("Processing ping request with ID: %v", req.ID)
|
||||||
|
s.processPing(r.Context(), client, req)
|
||||||
|
case methodNotificationsCancelled:
|
||||||
|
logx.Infof("Received notifications/cancelled notification: %v", req.ID)
|
||||||
|
s.processNotificationCancelled(r.Context(), client, req)
|
||||||
|
default:
|
||||||
|
logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID)
|
||||||
|
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSSE handles Server-Sent Events connections
|
||||||
|
func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Generate a unique session ID for this client
|
||||||
|
sessionId := uuid.New().String()
|
||||||
|
|
||||||
|
// Create new client with buffered channel to prevent blocking
|
||||||
|
client := &mcpClient{
|
||||||
|
id: sessionId,
|
||||||
|
channel: make(chan string, eventChanSize),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add client to active clients map
|
||||||
|
s.clientsLock.Lock()
|
||||||
|
s.clients[sessionId] = client
|
||||||
|
activeClients := len(s.clients)
|
||||||
|
s.clientsLock.Unlock()
|
||||||
|
|
||||||
|
logx.Infof("New SSE connection established for client %s (active clients: %d)",
|
||||||
|
sessionId, activeClients)
|
||||||
|
|
||||||
|
// Set proper SSE headers
|
||||||
|
w.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
|
||||||
|
// Enable streaming
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
logx.Error("Streaming not supported by the underlying http.ResponseWriter")
|
||||||
|
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the message endpoint URL to the client
|
||||||
|
endpoint := fmt.Sprintf("%s%s?%s=%s",
|
||||||
|
s.conf.Mcp.BaseUrl, s.conf.Mcp.MessageEndpoint, sessionIdKey, sessionId)
|
||||||
|
|
||||||
|
// Format and send the endpoint message
|
||||||
|
endpointMsg := formatSSEMessage(eventEndpoint, []byte(endpoint))
|
||||||
|
if _, err := fmt.Fprint(w, endpointMsg); err != nil {
|
||||||
|
logx.Errorf("Failed to send endpoint message to client %s: %v", sessionId, err)
|
||||||
|
s.cleanupClient(sessionId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
// Set up keep-alive ping and client cleanup
|
||||||
|
ticker := time.NewTicker(pingInterval.Load())
|
||||||
|
defer func() {
|
||||||
|
ticker.Stop()
|
||||||
|
s.cleanupClient(sessionId)
|
||||||
|
logx.Infof("SSE connection closed for client %s", sessionId)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Message processing loop
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case message, ok := <-client.channel:
|
||||||
|
if !ok {
|
||||||
|
// Channel was closed, end connection
|
||||||
|
logx.Infof("Client channel was closed for %s", sessionId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write message to the response
|
||||||
|
if _, err := fmt.Fprint(w, message); err != nil {
|
||||||
|
logx.Infof("Failed to write message to client %s: %v", sessionId, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
case <-ticker.C:
|
||||||
|
// Send keep-alive ping to maintain connection
|
||||||
|
ping := fmt.Sprintf(`{"type":"ping","timestamp":"%s"}`, time.Now().String())
|
||||||
|
pingMsg := formatSSEMessage("ping", []byte(ping))
|
||||||
|
if _, err := fmt.Fprint(w, pingMsg); err != nil {
|
||||||
|
logx.Errorf("Failed to send ping to client %s, closing connection: %v", sessionId, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
case <-r.Context().Done():
|
||||||
|
// Client disconnected or request was canceled or timed out
|
||||||
|
logx.Infof("Client %s disconnected: context done", sessionId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processInitialize processes the initialize request
|
||||||
|
func (s *sseMcpServer) processInitialize(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
// Create a proper JSON-RPC response that preserves the client's request ID
|
||||||
|
result := initializationResponse{
|
||||||
|
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
|
||||||
|
Capabilities: capabilities{
|
||||||
|
Prompts: struct {
|
||||||
|
ListChanged bool `json:"listChanged"`
|
||||||
|
}{
|
||||||
|
ListChanged: true,
|
||||||
|
},
|
||||||
|
Resources: struct {
|
||||||
|
Subscribe bool `json:"subscribe"`
|
||||||
|
ListChanged bool `json:"listChanged"`
|
||||||
|
}{
|
||||||
|
Subscribe: true,
|
||||||
|
ListChanged: true,
|
||||||
|
},
|
||||||
|
Tools: struct {
|
||||||
|
ListChanged bool `json:"listChanged"`
|
||||||
|
}{
|
||||||
|
ListChanged: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ServerInfo: serverInfo{
|
||||||
|
Name: s.conf.Mcp.Name,
|
||||||
|
Version: s.conf.Mcp.Version,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark client as initialized
|
||||||
|
client.initialized = true
|
||||||
|
|
||||||
|
// Send response with client's original request ID
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processListTools processes the tools/list request
|
||||||
|
func (s *sseMcpServer) processListTools(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
// Extract pagination params if any
|
||||||
|
var nextCursor string
|
||||||
|
var progressToken any
|
||||||
|
|
||||||
|
// Extract meta data including progress token
|
||||||
|
if req.Params != nil {
|
||||||
|
var metaParams struct {
|
||||||
|
Cursor string `json:"cursor"`
|
||||||
|
Meta struct {
|
||||||
|
ProgressToken any `json:"progressToken"`
|
||||||
|
} `json:"_meta"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
|
||||||
|
if len(metaParams.Cursor) > 0 {
|
||||||
|
nextCursor = metaParams.Cursor
|
||||||
|
}
|
||||||
|
progressToken = metaParams.Meta.ProgressToken
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolsList []Tool
|
||||||
|
s.toolsLock.Lock()
|
||||||
|
for _, tool := range s.tools {
|
||||||
|
if len(tool.InputSchema.Type) == 0 {
|
||||||
|
tool.InputSchema.Type = ContentTypeObject
|
||||||
|
}
|
||||||
|
toolsList = append(toolsList, tool)
|
||||||
|
}
|
||||||
|
s.toolsLock.Unlock()
|
||||||
|
|
||||||
|
result := ListToolsResult{
|
||||||
|
PaginatedResult: PaginatedResult{
|
||||||
|
Result: Result{},
|
||||||
|
NextCursor: Cursor(nextCursor),
|
||||||
|
},
|
||||||
|
Tools: toolsList,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add meta information if progress token was provided
|
||||||
|
if progressToken != nil {
|
||||||
|
result.Result.Meta = map[string]any{
|
||||||
|
progressTokenKey: progressToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processListPrompts processes the prompts/list request
|
||||||
|
func (s *sseMcpServer) processListPrompts(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
// Extract pagination params if any
|
||||||
|
var nextCursor string
|
||||||
|
if req.Params != nil {
|
||||||
|
var cursorParams struct {
|
||||||
|
Cursor string `json:"cursor"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(req.Params, &cursorParams); err == nil && cursorParams.Cursor != "" {
|
||||||
|
// If we have a valid cursor, we could use it for pagination
|
||||||
|
// For now, we're not actually implementing pagination, so this is just
|
||||||
|
// to show how it would be extracted from the request
|
||||||
|
_ = cursorParams.Cursor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare prompt list
|
||||||
|
var promptsList []Prompt
|
||||||
|
s.promptsLock.Lock()
|
||||||
|
for _, prompt := range s.prompts {
|
||||||
|
promptsList = append(promptsList, prompt)
|
||||||
|
}
|
||||||
|
s.promptsLock.Unlock()
|
||||||
|
|
||||||
|
// In a real implementation, you'd handle pagination here
|
||||||
|
// For now, we'll return all prompts at once
|
||||||
|
result := struct {
|
||||||
|
Prompts []Prompt `json:"prompts"`
|
||||||
|
NextCursor string `json:"nextCursor,omitempty"`
|
||||||
|
Meta *struct{} `json:"_meta,omitempty"`
|
||||||
|
}{
|
||||||
|
Prompts: promptsList,
|
||||||
|
NextCursor: nextCursor,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processListResources processes the resources/list request
|
||||||
|
func (s *sseMcpServer) processListResources(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
// Extract pagination params if any
|
||||||
|
var nextCursor string
|
||||||
|
var progressToken any
|
||||||
|
|
||||||
|
// Extract meta information including progress token if available
|
||||||
|
if req.Params != nil {
|
||||||
|
var metaParams PaginatedParams
|
||||||
|
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
|
||||||
|
if len(metaParams.Cursor) > 0 {
|
||||||
|
nextCursor = metaParams.Cursor
|
||||||
|
}
|
||||||
|
progressToken = metaParams.Meta.ProgressToken
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var resourcesList []Resource
|
||||||
|
s.resourcesLock.Lock()
|
||||||
|
for _, resource := range s.resources {
|
||||||
|
// Create a copy without the handler function which shouldn't be sent to clients
|
||||||
|
resourceCopy := Resource{
|
||||||
|
URI: resource.URI,
|
||||||
|
Name: resource.Name,
|
||||||
|
Description: resource.Description,
|
||||||
|
MimeType: resource.MimeType,
|
||||||
|
}
|
||||||
|
resourcesList = append(resourcesList, resourceCopy)
|
||||||
|
}
|
||||||
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
|
// Create proper ResourcesListResult according to MCP specification
|
||||||
|
result := ResourcesListResult{
|
||||||
|
PaginatedResult: PaginatedResult{
|
||||||
|
Result: Result{},
|
||||||
|
NextCursor: Cursor(nextCursor),
|
||||||
|
},
|
||||||
|
Resources: resourcesList,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add meta information if progress token was provided
|
||||||
|
if progressToken != nil {
|
||||||
|
result.Result.Meta = map[string]any{
|
||||||
|
progressTokenKey: progressToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processGetPrompt processes the prompts/get request
|
||||||
|
func (s *sseMcpServer) processGetPrompt(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
type GetPromptParams struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments map[string]string `json:"arguments,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var params GetPromptParams
|
||||||
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if prompt exists
|
||||||
|
s.promptsLock.Lock()
|
||||||
|
prompt, exists := s.prompts[params.Name]
|
||||||
|
s.promptsLock.Unlock()
|
||||||
|
if !exists {
|
||||||
|
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.Infof("Processing prompt request: %s with %d arguments", prompt.Name, len(params.Arguments))
|
||||||
|
|
||||||
|
// Validate required arguments
|
||||||
|
missingArgs := validatePromptArguments(prompt, params.Arguments)
|
||||||
|
if len(missingArgs) > 0 {
|
||||||
|
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure arguments are initialized to an empty map if nil
|
||||||
|
if params.Arguments == nil {
|
||||||
|
params.Arguments = make(map[string]string)
|
||||||
|
}
|
||||||
|
args := params.Arguments
|
||||||
|
|
||||||
|
// Generate messages using handler or static content
|
||||||
|
var messages []PromptMessage
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if prompt.Handler != nil {
|
||||||
|
// Use dynamic handler to generate messages
|
||||||
|
messages, err = prompt.Handler(ctx, args)
|
||||||
|
if err != nil {
|
||||||
|
logx.Errorf("Error from prompt handler: %v", err)
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID,
|
||||||
|
fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No handler, generate messages from static content
|
||||||
|
var messageText string
|
||||||
|
if len(prompt.Content) > 0 {
|
||||||
|
messageText = prompt.Content
|
||||||
|
|
||||||
|
// Apply argument substitutions to static content
|
||||||
|
for key, value := range args {
|
||||||
|
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||||
|
messageText = strings.Replace(messageText, placeholder, value, -1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a single user message with the content
|
||||||
|
messages = []PromptMessage{
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: TextContent{
|
||||||
|
Text: messageText,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct the response according to MCP spec
|
||||||
|
result := struct {
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Messages []PromptMessage `json:"messages"`
|
||||||
|
}{
|
||||||
|
Description: prompt.Description,
|
||||||
|
Messages: toTypedPromptMessages(messages),
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processToolCall processes the tools/call request
|
||||||
|
func (s *sseMcpServer) processToolCall(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
var toolCallParams struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments map[string]any `json:"arguments,omitempty"`
|
||||||
|
Meta struct {
|
||||||
|
ProgressToken any `json:"progressToken"`
|
||||||
|
} `json:"_meta,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle different types of req.Params
|
||||||
|
// If it's a RawMessage (JSON), unmarshal it
|
||||||
|
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
|
||||||
|
logx.Errorf("Failed to unmarshal tool call params: %v", err)
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract progress token if available
|
||||||
|
progressToken := toolCallParams.Meta.ProgressToken
|
||||||
|
|
||||||
|
// Find the requested tool
|
||||||
|
s.toolsLock.Lock()
|
||||||
|
tool, exists := s.tools[toolCallParams.Name]
|
||||||
|
s.toolsLock.Unlock()
|
||||||
|
if !exists {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||||
|
toolCallParams.Name), errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log parameters before execution
|
||||||
|
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
|
||||||
|
|
||||||
|
// Execute the tool handler with timeout handling
|
||||||
|
var result any
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Create a channel to receive the result
|
||||||
|
// make sure to have 1 size buffer to avoid channel leak if timeout
|
||||||
|
resultCh := make(chan struct {
|
||||||
|
result any
|
||||||
|
err error
|
||||||
|
}, 1)
|
||||||
|
|
||||||
|
// Execute the tool handler in a goroutine
|
||||||
|
go func() {
|
||||||
|
toolResult, toolErr := tool.Handler(ctx, toolCallParams.Arguments)
|
||||||
|
resultCh <- struct {
|
||||||
|
result any
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
result: toolResult,
|
||||||
|
err: toolErr,
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for either the result or a timeout
|
||||||
|
select {
|
||||||
|
case res := <-resultCh:
|
||||||
|
result = res.result
|
||||||
|
err = res.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Handle request timeout
|
||||||
|
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.MessageTimeout, toolCallParams.Name)
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the base result structure with metadata
|
||||||
|
callToolResult := CallToolResult{
|
||||||
|
Result: Result{},
|
||||||
|
Content: []any{},
|
||||||
|
IsError: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add meta information if progress token was provided
|
||||||
|
if progressToken != nil {
|
||||||
|
callToolResult.Result.Meta = map[string]any{
|
||||||
|
progressTokenKey: progressToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if there was an error during tool execution
|
||||||
|
if err != nil {
|
||||||
|
// According to the spec, for tool-level errors (as opposed to protocol-level errors),
|
||||||
|
// we should report them inside the result with isError=true
|
||||||
|
logx.Errorf("Tool execution reported error: %v", err)
|
||||||
|
|
||||||
|
callToolResult.Content = []any{
|
||||||
|
TextContent{
|
||||||
|
Text: fmt.Sprintf("Error: %v", err),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
callToolResult.IsError = true
|
||||||
|
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format the response according to the CallToolResult schema
|
||||||
|
switch v := result.(type) {
|
||||||
|
case string:
|
||||||
|
// Simple string becomes text content
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Text: v,
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case map[string]any:
|
||||||
|
// JSON-like object becomes formatted JSON text
|
||||||
|
jsonStr, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
jsonStr = []byte(err.Error())
|
||||||
|
}
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Text: string(jsonStr),
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case TextContent:
|
||||||
|
callToolResult.Content = append(callToolResult.Content, v)
|
||||||
|
case ImageContent:
|
||||||
|
callToolResult.Content = append(callToolResult.Content, v)
|
||||||
|
case []any:
|
||||||
|
callToolResult.Content = v
|
||||||
|
case ToolResult:
|
||||||
|
// Handle legacy ToolResult type
|
||||||
|
switch v.Type {
|
||||||
|
case ContentTypeText:
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Text: fmt.Sprintf("%v", v.Content),
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case ContentTypeImage:
|
||||||
|
if imgData, ok := v.Content.(map[string]any); ok {
|
||||||
|
callToolResult.Content = append(callToolResult.Content, ImageContent{
|
||||||
|
Data: fmt.Sprintf("%v", imgData["data"]),
|
||||||
|
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Text: fmt.Sprintf("%v", v.Content),
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// For any other type, convert to string
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Text: fmt.Sprintf("%v", v),
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
callToolResult.Content = toTypedContents(callToolResult.Content)
|
||||||
|
logx.Infof("Tool call result: %#v", callToolResult)
|
||||||
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processResourcesRead processes the resources/read request
|
||||||
|
func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
var params ResourceReadParams
|
||||||
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find resource that matches the URI
|
||||||
|
s.resourcesLock.Lock()
|
||||||
|
resource, exists := s.resources[params.URI]
|
||||||
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||||
|
params.URI), errCodeResourceNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no handler is provided, return an empty content array
|
||||||
|
if resource.Handler == nil {
|
||||||
|
result := ResourceReadResult{
|
||||||
|
Contents: []ResourceContent{
|
||||||
|
{
|
||||||
|
URI: params.URI,
|
||||||
|
MimeType: resource.MimeType,
|
||||||
|
Text: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the resource handler
|
||||||
|
content, err := resource.Handler(ctx)
|
||||||
|
if err != nil {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||||
|
errCodeInternalError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the URI is set if not already provided by the handler
|
||||||
|
if len(content.URI) == 0 {
|
||||||
|
content.URI = params.URI
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure MimeType is set if available from the resource definition
|
||||||
|
if len(content.MimeType) == 0 && len(resource.MimeType) > 0 {
|
||||||
|
content.MimeType = resource.MimeType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create response with contents from the handler
|
||||||
|
// The MCP specification requires a contents array
|
||||||
|
result := ResourceReadResult{
|
||||||
|
Contents: []ResourceContent{content},
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processResourceSubscribe processes the resources/subscribe request
|
||||||
|
func (s *sseMcpServer) processResourceSubscribe(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
var params ResourceSubscribeParams
|
||||||
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the resource exists
|
||||||
|
s.resourcesLock.Lock()
|
||||||
|
_, exists := s.resources[params.URI]
|
||||||
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||||
|
params.URI), errCodeResourceNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send success response for the subscription
|
||||||
|
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// processNotificationCancelled processes the notifications/cancelled notification
|
||||||
|
func (s *sseMcpServer) processNotificationCancelled(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
// Extract the requestId that was canceled
|
||||||
|
type CancelParams struct {
|
||||||
|
RequestId int64 `json:"requestId"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var params CancelParams
|
||||||
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
|
logx.Errorf("Failed to parse cancellation params: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.Infof("Request %d was cancelled by client. Reason: %s", params.RequestId, params.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processNotificationInitialized processes the notifications/initialized notification
|
||||||
|
func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
|
||||||
|
// Mark the client as properly initialized
|
||||||
|
client.initialized = true
|
||||||
|
logx.Infof("Client %s is now fully initialized and ready for normal operations", client.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processPing processes the ping request and responds immediately
|
||||||
|
func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
// A ping request should simply respond with an empty result to confirm the server is alive
|
||||||
|
logx.Infof("Received ping request with ID: %d", req.ID)
|
||||||
|
|
||||||
|
// Send an empty response with client's original request ID
|
||||||
|
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendErrorResponse sends an error response via the SSE channel
|
||||||
|
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||||
|
id any, message string, code int) {
|
||||||
|
errorResponse := struct {
|
||||||
|
JsonRpc string `json:"jsonrpc"`
|
||||||
|
ID any `json:"id"`
|
||||||
|
Error errorMessage `json:"error"`
|
||||||
|
}{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: id,
|
||||||
|
Error: errorMessage{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// all fields are primitive types, impossible to fail
|
||||||
|
jsonData, _ := json.Marshal(errorResponse)
|
||||||
|
// Use CRLF line endings as requested
|
||||||
|
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||||
|
logx.Infof("Sending error for ID %v: %s", id, sseMessage)
|
||||||
|
|
||||||
|
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||||
|
select {
|
||||||
|
case client.channel <- sseMessage:
|
||||||
|
default:
|
||||||
|
// Channel buffer is full, log warning and continue
|
||||||
|
logx.Infof("Client %s channel is full while sending error response with ID %d", client.id, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendResponse sends a success response via the SSE channel
|
||||||
|
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) {
|
||||||
|
response := Response{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: id,
|
||||||
|
Result: result,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
s.sendErrorResponse(ctx, client, id, "Failed to marshal response", errCodeInternalError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use CRLF line endings as requested
|
||||||
|
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||||
|
logx.Infof("Sending response for ID %v: %s", id, sseMessage)
|
||||||
|
|
||||||
|
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||||
|
select {
|
||||||
|
case client.channel <- sseMessage:
|
||||||
|
default:
|
||||||
|
// Channel buffer is full, log warning and continue
|
||||||
|
logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
3451
mcp/server_test.go
Normal file
3451
mcp/server_test.go
Normal file
File diff suppressed because it is too large
Load Diff
317
mcp/types.go
Normal file
317
mcp/types.go
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/rest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cursor is an opaque token used for pagination
|
||||||
|
type Cursor string
|
||||||
|
|
||||||
|
// Request represents a generic MCP request following JSON-RPC 2.0 specification
|
||||||
|
type Request struct {
|
||||||
|
SessionId string `form:"session_id"` // Session identifier for client tracking
|
||||||
|
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
|
||||||
|
ID any `json:"id"` // Request identifier for matching responses
|
||||||
|
Method string `json:"method"` // Method name to invoke
|
||||||
|
Params json.RawMessage `json:"params"` // Parameters for the method
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Request) isNotification() (bool, error) {
|
||||||
|
switch val := r.ID.(type) {
|
||||||
|
case int:
|
||||||
|
return val == 0, nil
|
||||||
|
case int64:
|
||||||
|
return val == 0, nil
|
||||||
|
case float64:
|
||||||
|
return val == 0.0, nil
|
||||||
|
case string:
|
||||||
|
return len(val) == 0, nil
|
||||||
|
case nil:
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("invalid type %T", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type PaginatedParams struct {
|
||||||
|
Cursor string `json:"cursor"`
|
||||||
|
Meta struct {
|
||||||
|
ProgressToken any `json:"progressToken"`
|
||||||
|
} `json:"_meta"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Result is the base interface for all results
|
||||||
|
type Result struct {
|
||||||
|
Meta map[string]any `json:"_meta,omitempty"` // Optional metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// PaginatedResult is a base for results that support pagination
|
||||||
|
type PaginatedResult struct {
|
||||||
|
Result
|
||||||
|
NextCursor Cursor `json:"nextCursor,omitempty"` // Opaque token for fetching next page
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListToolsResult represents the response to a tools/list request
|
||||||
|
type ListToolsResult struct {
|
||||||
|
PaginatedResult
|
||||||
|
Tools []Tool `json:"tools"` // List of available tools
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message Content Types
|
||||||
|
|
||||||
|
// RoleType represents the sender or recipient of messages in a conversation
|
||||||
|
type RoleType string
|
||||||
|
|
||||||
|
// PromptArgument defines a single argument that can be passed to a prompt
|
||||||
|
type PromptArgument struct {
|
||||||
|
Name string `json:"name"` // Argument name
|
||||||
|
Description string `json:"description,omitempty"` // Human-readable description
|
||||||
|
Required bool `json:"required,omitempty"` // Whether this argument is required
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromptHandler is a function that dynamically generates prompt content
|
||||||
|
type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error)
|
||||||
|
|
||||||
|
// Prompt represents an MCP Prompt definition
|
||||||
|
type Prompt struct {
|
||||||
|
Name string `json:"name"` // Unique identifier for the prompt
|
||||||
|
Description string `json:"description,omitempty"` // Human-readable description
|
||||||
|
Arguments []PromptArgument `json:"arguments,omitempty"` // Arguments for customization
|
||||||
|
Content string `json:"-"` // Static content (internal use only)
|
||||||
|
Handler PromptHandler `json:"-"` // Handler for dynamic content generation
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromptMessage represents a message in a conversation
|
||||||
|
type PromptMessage struct {
|
||||||
|
Role RoleType `json:"role"` // Message sender role
|
||||||
|
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextContent represents text content in a message
|
||||||
|
type TextContent struct {
|
||||||
|
Text string `json:"text"` // The text content
|
||||||
|
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
||||||
|
}
|
||||||
|
|
||||||
|
type typedTextContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
TextContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageContent represents image data in a message
|
||||||
|
type ImageContent struct {
|
||||||
|
Data string `json:"data"` // Base64-encoded image data
|
||||||
|
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
||||||
|
}
|
||||||
|
|
||||||
|
type typedImageContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
ImageContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// AudioContent represents audio data in a message
|
||||||
|
type AudioContent struct {
|
||||||
|
Data string `json:"data"` // Base64-encoded audio data
|
||||||
|
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
||||||
|
}
|
||||||
|
|
||||||
|
type typedAudioContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
AudioContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileContent represents file content
|
||||||
|
type FileContent struct {
|
||||||
|
URI string `json:"uri"` // URI identifying the file
|
||||||
|
MimeType string `json:"mimeType"` // MIME type of the file
|
||||||
|
Text string `json:"text"` // File content as text
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddedResource represents a resource embedded in a message
|
||||||
|
type EmbeddedResource struct {
|
||||||
|
Type string `json:"type"` // Always "resource"
|
||||||
|
Resource ResourceContent `json:"resource"` // The resource data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Annotations provides additional metadata for content
|
||||||
|
type Annotations struct {
|
||||||
|
Audience []RoleType `json:"audience,omitempty"` // Who should see this content
|
||||||
|
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool-related Types
|
||||||
|
|
||||||
|
// ToolHandler is a function that handles tool calls
|
||||||
|
type ToolHandler func(ctx context.Context, params map[string]any) (any, error)
|
||||||
|
|
||||||
|
// Tool represents a Model Context Protocol Tool definition
|
||||||
|
type Tool struct {
|
||||||
|
Name string `json:"name"` // Unique identifier for the tool
|
||||||
|
Description string `json:"description"` // Human-readable description
|
||||||
|
InputSchema InputSchema `json:"inputSchema"` // JSON Schema for parameters
|
||||||
|
Handler ToolHandler `json:"-"` // Not sent to clients
|
||||||
|
}
|
||||||
|
|
||||||
|
// InputSchema represents tool's input schema in JSON Schema format
|
||||||
|
type InputSchema struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties map[string]any `json:"properties"` // Property definitions
|
||||||
|
Required []string `json:"required,omitempty"` // List of required properties
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallToolResult represents a tool call result that conforms to the MCP schema
|
||||||
|
type CallToolResult struct {
|
||||||
|
Result
|
||||||
|
Content []any `json:"content"` // Content items (text, images, etc.)
|
||||||
|
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resource represents a Model Context Protocol Resource definition
|
||||||
|
type Resource struct {
|
||||||
|
URI string `json:"uri"` // Unique resource identifier (RFC3986)
|
||||||
|
Name string `json:"name"` // Human-readable name
|
||||||
|
Description string `json:"description,omitempty"` // Optional description
|
||||||
|
MimeType string `json:"mimeType,omitempty"` // Optional MIME type
|
||||||
|
Handler ResourceHandler `json:"-"` // Internal handler not sent to clients
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceHandler is a function that handles resource read requests
|
||||||
|
type ResourceHandler func(ctx context.Context) (ResourceContent, error)
|
||||||
|
|
||||||
|
// ResourceContent represents the content of a resource
|
||||||
|
type ResourceContent struct {
|
||||||
|
URI string `json:"uri"` // Resource URI (required)
|
||||||
|
MimeType string `json:"mimeType,omitempty"` // MIME type of the resource
|
||||||
|
Text string `json:"text,omitempty"` // Text content (if available)
|
||||||
|
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourcesListResult represents the response to a resources/list request
|
||||||
|
type ResourcesListResult struct {
|
||||||
|
PaginatedResult
|
||||||
|
Resources []Resource `json:"resources"` // List of available resources
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceReadParams contains parameters for a resources/read request
|
||||||
|
type ResourceReadParams struct {
|
||||||
|
URI string `json:"uri"` // URI of the resource to read
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceReadResult contains the result of a resources/read request
|
||||||
|
type ResourceReadResult struct {
|
||||||
|
Result
|
||||||
|
Contents []ResourceContent `json:"contents"` // Array of resource content
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceSubscribeParams contains parameters for a resources/subscribe request
|
||||||
|
type ResourceSubscribeParams struct {
|
||||||
|
URI string `json:"uri"` // URI of the resource to subscribe to
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceUpdateNotification represents a notification about a resource update
|
||||||
|
type ResourceUpdateNotification struct {
|
||||||
|
URI string `json:"uri"` // URI of the updated resource
|
||||||
|
Content ResourceContent `json:"content"` // New resource content
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client and Server Types
|
||||||
|
|
||||||
|
// mcpClient represents an SSE client connection
|
||||||
|
type mcpClient struct {
|
||||||
|
id string // Unique client identifier
|
||||||
|
channel chan string // Channel for sending SSE messages
|
||||||
|
initialized bool // Tracks if client has sent notifications/initialized
|
||||||
|
}
|
||||||
|
|
||||||
|
// McpServer defines the interface for Model Context Protocol servers
|
||||||
|
type McpServer interface {
|
||||||
|
Start()
|
||||||
|
Stop()
|
||||||
|
RegisterTool(tool Tool) error
|
||||||
|
RegisterPrompt(prompt Prompt)
|
||||||
|
RegisterResource(resource Resource)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sseMcpServer implements the McpServer interface using SSE
|
||||||
|
type sseMcpServer struct {
|
||||||
|
conf McpConf
|
||||||
|
server *rest.Server
|
||||||
|
clients map[string]*mcpClient
|
||||||
|
clientsLock sync.Mutex
|
||||||
|
tools map[string]Tool
|
||||||
|
toolsLock sync.Mutex
|
||||||
|
prompts map[string]Prompt
|
||||||
|
promptsLock sync.Mutex
|
||||||
|
resources map[string]Resource
|
||||||
|
resourcesLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response Types
|
||||||
|
|
||||||
|
// errorObj represents a JSON-RPC error object
|
||||||
|
type errorObj struct {
|
||||||
|
Code int `json:"code"` // Error code
|
||||||
|
Message string `json:"message"` // Error message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response represents a JSON-RPC response
|
||||||
|
type Response struct {
|
||||||
|
JsonRpc string `json:"jsonrpc"` // Always "2.0"
|
||||||
|
ID any `json:"id"` // Same as request ID
|
||||||
|
Result any `json:"result"` // Result object (null if error)
|
||||||
|
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server Information Types
|
||||||
|
|
||||||
|
// serverInfo provides information about the server
|
||||||
|
type serverInfo struct {
|
||||||
|
Name string `json:"name"` // Server name
|
||||||
|
Version string `json:"version"` // Server version
|
||||||
|
}
|
||||||
|
|
||||||
|
// capabilities describes the server's capabilities
|
||||||
|
type capabilities struct {
|
||||||
|
Logging struct{} `json:"logging"`
|
||||||
|
Prompts struct {
|
||||||
|
ListChanged bool `json:"listChanged"` // Server will notify on prompt changes
|
||||||
|
} `json:"prompts"`
|
||||||
|
Resources struct {
|
||||||
|
Subscribe bool `json:"subscribe"` // Server supports resource subscriptions
|
||||||
|
ListChanged bool `json:"listChanged"` // Server will notify on resource changes
|
||||||
|
} `json:"resources"`
|
||||||
|
Tools struct {
|
||||||
|
ListChanged bool `json:"listChanged"` // Server will notify on tool changes
|
||||||
|
} `json:"tools"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// initializationResponse is sent in response to an initialize request
|
||||||
|
type initializationResponse struct {
|
||||||
|
ProtocolVersion string `json:"protocolVersion"` // Protocol version
|
||||||
|
Capabilities capabilities `json:"capabilities"` // Server capabilities
|
||||||
|
ServerInfo serverInfo `json:"serverInfo"` // Server information
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolCallParams contains the parameters for a tool call
|
||||||
|
type ToolCallParams struct {
|
||||||
|
Name string `json:"name"` // Tool name
|
||||||
|
Parameters map[string]any `json:"parameters"` // Tool parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolResult contains the result of a tool execution
|
||||||
|
type ToolResult struct {
|
||||||
|
Type string `json:"type"` // Content type (text, image, etc.)
|
||||||
|
Content any `json:"content"` // Result content
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorMessage represents a detailed error message
|
||||||
|
type errorMessage struct {
|
||||||
|
Code int `json:"code"` // Error code
|
||||||
|
Message string `json:"message"` // Error message
|
||||||
|
Data any `json:",omitempty"` // Additional error data
|
||||||
|
}
|
||||||
271
mcp/types_test.go
Normal file
271
mcp/types_test.go
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResponseMarshaling(t *testing.T) {
|
||||||
|
// Test that the Response struct marshals correctly
|
||||||
|
resp := Response{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 123,
|
||||||
|
Result: map[string]string{
|
||||||
|
"key": "value",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
||||||
|
assert.Contains(t, string(data), `"id":123`)
|
||||||
|
assert.Contains(t, string(data), `"result":{"key":"value"}`)
|
||||||
|
|
||||||
|
// Test response with error
|
||||||
|
respWithError := Response{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 456,
|
||||||
|
Error: &errorObj{
|
||||||
|
Code: errCodeInvalidRequest,
|
||||||
|
Message: "Invalid Request",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = json.Marshal(respWithError)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
||||||
|
assert.Contains(t, string(data), `"id":456`)
|
||||||
|
assert.Contains(t, string(data), `"error":{"code":-32600,"message":"Invalid Request"}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestUnmarshaling(t *testing.T) {
|
||||||
|
// Test that the Request struct unmarshals correctly
|
||||||
|
jsonStr := `{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 789,
|
||||||
|
"method": "test_method",
|
||||||
|
"params": {"key": "value"}
|
||||||
|
}`
|
||||||
|
|
||||||
|
var req Request
|
||||||
|
err := json.Unmarshal([]byte(jsonStr), &req)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "2.0", req.JsonRpc)
|
||||||
|
assert.Equal(t, float64(789), req.ID)
|
||||||
|
assert.Equal(t, "test_method", req.Method)
|
||||||
|
|
||||||
|
// Check params unmarshaled correctly
|
||||||
|
var params map[string]string
|
||||||
|
err = json.Unmarshal(req.Params, ¶ms)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "value", params["key"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolStructs(t *testing.T) {
|
||||||
|
// Test Tool struct
|
||||||
|
tool := Tool{
|
||||||
|
Name: "test.tool",
|
||||||
|
Description: "A test tool",
|
||||||
|
InputSchema: InputSchema{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]any{
|
||||||
|
"input": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Input parameter",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"input"},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
return "result", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields are correct
|
||||||
|
assert.Equal(t, "test.tool", tool.Name)
|
||||||
|
assert.Equal(t, "A test tool", tool.Description)
|
||||||
|
assert.Equal(t, "object", tool.InputSchema.Type)
|
||||||
|
assert.Contains(t, tool.InputSchema.Properties, "input")
|
||||||
|
propMap, ok := tool.InputSchema.Properties["input"].(map[string]any)
|
||||||
|
assert.True(t, ok, "Property should be a map")
|
||||||
|
assert.Equal(t, "string", propMap["type"])
|
||||||
|
assert.NotNil(t, tool.Handler)
|
||||||
|
|
||||||
|
// Verify JSON marshalling (which should exclude Handler function)
|
||||||
|
data, err := json.Marshal(tool)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"name":"test.tool"`)
|
||||||
|
assert.Contains(t, string(data), `"description":"A test tool"`)
|
||||||
|
assert.Contains(t, string(data), `"inputSchema":`)
|
||||||
|
assert.NotContains(t, string(data), `"Handler":`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPromptStructs(t *testing.T) {
|
||||||
|
// Test Prompt struct
|
||||||
|
prompt := Prompt{
|
||||||
|
Name: "test.prompt",
|
||||||
|
Description: "A test prompt description",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields are correct
|
||||||
|
assert.Equal(t, "test.prompt", prompt.Name)
|
||||||
|
assert.Equal(t, "A test prompt description", prompt.Description)
|
||||||
|
|
||||||
|
// Verify JSON marshalling
|
||||||
|
data, err := json.Marshal(prompt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"name":"test.prompt"`)
|
||||||
|
assert.Contains(t, string(data), `"description":"A test prompt description"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResourceStructs(t *testing.T) {
|
||||||
|
// Test Resource struct
|
||||||
|
resource := Resource{
|
||||||
|
Name: "test.resource",
|
||||||
|
URI: "http://example.com/resource",
|
||||||
|
Description: "A test resource",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields are correct
|
||||||
|
assert.Equal(t, "test.resource", resource.Name)
|
||||||
|
assert.Equal(t, "http://example.com/resource", resource.URI)
|
||||||
|
assert.Equal(t, "A test resource", resource.Description)
|
||||||
|
|
||||||
|
// Verify JSON marshalling
|
||||||
|
data, err := json.Marshal(resource)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"name":"test.resource"`)
|
||||||
|
assert.Contains(t, string(data), `"uri":"http://example.com/resource"`)
|
||||||
|
assert.Contains(t, string(data), `"description":"A test resource"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContentTypes(t *testing.T) {
|
||||||
|
// Test TextContent
|
||||||
|
textContent := TextContent{
|
||||||
|
Text: "Sample text",
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
Priority: ptr(1.0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(textContent)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"text":"Sample text"`)
|
||||||
|
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
||||||
|
assert.Contains(t, string(data), `"priority":1`)
|
||||||
|
|
||||||
|
// Test ImageContent
|
||||||
|
imageContent := ImageContent{
|
||||||
|
Data: "base64data",
|
||||||
|
MimeType: "image/png",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = json.Marshal(imageContent)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"data":"base64data"`)
|
||||||
|
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
||||||
|
|
||||||
|
// Test AudioContent
|
||||||
|
audioContent := AudioContent{
|
||||||
|
Data: "base64audio",
|
||||||
|
MimeType: "audio/mp3",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = json.Marshal(audioContent)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"data":"base64audio"`)
|
||||||
|
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallToolResult(t *testing.T) {
|
||||||
|
// Test CallToolResult
|
||||||
|
result := CallToolResult{
|
||||||
|
Result: Result{
|
||||||
|
Meta: map[string]any{
|
||||||
|
"progressToken": "token123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Content: []interface{}{
|
||||||
|
TextContent{
|
||||||
|
Text: "Sample result",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IsError: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(result)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
||||||
|
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
||||||
|
assert.NotContains(t, string(data), `"isError":`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequest_isNotification(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
id any
|
||||||
|
want bool
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
// integer test cases
|
||||||
|
{name: "int zero", id: 0, want: true, wantErr: nil},
|
||||||
|
{name: "int non-zero", id: 1, want: false, wantErr: nil},
|
||||||
|
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
|
||||||
|
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
|
||||||
|
|
||||||
|
// floating point number test cases
|
||||||
|
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
|
||||||
|
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
|
||||||
|
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
|
||||||
|
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
|
||||||
|
|
||||||
|
// string test cases
|
||||||
|
{name: "empty string", id: "", want: true, wantErr: nil},
|
||||||
|
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
|
||||||
|
{name: "space string", id: " ", want: false, wantErr: nil},
|
||||||
|
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
|
||||||
|
|
||||||
|
// special cases
|
||||||
|
{name: "nil", id: nil, want: true, wantErr: nil},
|
||||||
|
|
||||||
|
// logical type test cases
|
||||||
|
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
|
||||||
|
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
|
||||||
|
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
|
||||||
|
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
|
||||||
|
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
|
||||||
|
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
|
||||||
|
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := Request{
|
||||||
|
SessionId: "test-session",
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: tt.id,
|
||||||
|
Method: "testMethod",
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := req.isNotification()
|
||||||
|
|
||||||
|
if (err != nil) != (tt.wantErr != nil) {
|
||||||
|
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
|
||||||
|
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
107
mcp/util.go
Normal file
107
mcp/util.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
||||||
|
func formatSSEMessage(event string, data []byte) string {
|
||||||
|
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ptr is a helper function to get a pointer to a value
|
||||||
|
func ptr[T any](v T) *T {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|
||||||
|
func toTypedContents(contents []any) []any {
|
||||||
|
typedContents := make([]any, len(contents))
|
||||||
|
|
||||||
|
for i, content := range contents {
|
||||||
|
switch v := content.(type) {
|
||||||
|
case TextContent:
|
||||||
|
typedContents[i] = typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: v,
|
||||||
|
}
|
||||||
|
case ImageContent:
|
||||||
|
typedContents[i] = typedImageContent{
|
||||||
|
Type: ContentTypeImage,
|
||||||
|
ImageContent: v,
|
||||||
|
}
|
||||||
|
case AudioContent:
|
||||||
|
typedContents[i] = typedAudioContent{
|
||||||
|
Type: ContentTypeAudio,
|
||||||
|
AudioContent: v,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
typedContents[i] = typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: TextContent{
|
||||||
|
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return typedContents
|
||||||
|
}
|
||||||
|
|
||||||
|
func toTypedPromptMessages(messages []PromptMessage) []PromptMessage {
|
||||||
|
typedMessages := make([]PromptMessage, len(messages))
|
||||||
|
|
||||||
|
for i, msg := range messages {
|
||||||
|
switch v := msg.Content.(type) {
|
||||||
|
case TextContent:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case ImageContent:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedImageContent{
|
||||||
|
Type: ContentTypeImage,
|
||||||
|
ImageContent: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case AudioContent:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedAudioContent{
|
||||||
|
Type: ContentTypeAudio,
|
||||||
|
AudioContent: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: TextContent{
|
||||||
|
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return typedMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePromptArguments checks if all required arguments are provided
|
||||||
|
// Returns a list of missing required arguments
|
||||||
|
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||||
|
var missingArgs []string
|
||||||
|
|
||||||
|
for _, arg := range prompt.Arguments {
|
||||||
|
if arg.Required {
|
||||||
|
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||||
|
missingArgs = append(missingArgs, arg.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return missingArgs
|
||||||
|
}
|
||||||
274
mcp/util_test.go
Normal file
274
mcp/util_test.go
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Event struct {
|
||||||
|
Type string
|
||||||
|
Data map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEvent(input string) (*Event, error) {
|
||||||
|
var evt Event
|
||||||
|
var dataStr string
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(strings.NewReader(input))
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if strings.HasPrefix(line, "event:") {
|
||||||
|
evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||||
|
} else if strings.HasPrefix(line, "data:") {
|
||||||
|
dataStr = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(dataStr) > 0 {
|
||||||
|
if err := json.Unmarshal([]byte(dataStr), &evt.Data); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse data: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &evt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToTypedPromptMessages tests the toTypedPromptMessages function
|
||||||
|
func TestToTypedPromptMessages(t *testing.T) {
|
||||||
|
// Test with multiple message types in one test
|
||||||
|
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||||
|
// Create test data with different content types
|
||||||
|
messages := []PromptMessage{
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: TextContent{
|
||||||
|
Text: "Hello, this is a text message",
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
Priority: ptr(0.8),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: RoleAssistant,
|
||||||
|
Content: ImageContent{
|
||||||
|
Data: "base64ImageData",
|
||||||
|
MimeType: "image/jpeg",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: AudioContent{
|
||||||
|
Data: "base64AudioData",
|
||||||
|
MimeType: "audio/mp3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: "This is a simple string that should be handled as unknown type",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the function
|
||||||
|
result := toTypedPromptMessages(messages)
|
||||||
|
|
||||||
|
// Validate results
|
||||||
|
require.Len(t, result, 4, "Should return the same number of messages")
|
||||||
|
|
||||||
|
// Validate first message (TextContent)
|
||||||
|
msg := result[0]
|
||||||
|
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Type assertion using reflection since Content is an interface
|
||||||
|
typed, ok := msg.Content.(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||||
|
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||||
|
assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||||
|
|
||||||
|
// Validate second message (ImageContent)
|
||||||
|
msg = result[1]
|
||||||
|
assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Type assertion for image content
|
||||||
|
typedImg, ok := msg.Content.(typedImageContent)
|
||||||
|
require.True(t, ok, "Should be typedImageContent")
|
||||||
|
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||||
|
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||||
|
assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate third message (AudioContent)
|
||||||
|
msg = result[2]
|
||||||
|
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Type assertion for audio content
|
||||||
|
typedAudio, ok := msg.Content.(typedAudioContent)
|
||||||
|
require.True(t, ok, "Should be typedAudioContent")
|
||||||
|
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||||
|
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||||
|
assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate fourth message (unknown type converted to TextContent)
|
||||||
|
msg = result[3]
|
||||||
|
assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Should be converted to a typedTextContent with error message
|
||||||
|
typedUnknown, ok := msg.Content.(typedTextContent)
|
||||||
|
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test empty input
|
||||||
|
t.Run("EmptyInput", func(t *testing.T) {
|
||||||
|
messages := []PromptMessage{}
|
||||||
|
result := toTypedPromptMessages(messages)
|
||||||
|
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with nil annotations
|
||||||
|
t.Run("NilAnnotations", func(t *testing.T) {
|
||||||
|
messages := []PromptMessage{
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: TextContent{
|
||||||
|
Text: "Text with nil annotations",
|
||||||
|
Annotations: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := toTypedPromptMessages(messages)
|
||||||
|
require.Len(t, result, 1, "Should return one message")
|
||||||
|
|
||||||
|
typed, ok := result[0].Content.(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||||
|
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToTypedContents tests the toTypedContents function
|
||||||
|
func TestToTypedContents(t *testing.T) {
|
||||||
|
// Test with multiple content types in one test
|
||||||
|
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||||
|
// Create test data with different content types
|
||||||
|
contents := []any{
|
||||||
|
TextContent{
|
||||||
|
Text: "Hello, this is a text content",
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
Priority: ptr(0.7),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ImageContent{
|
||||||
|
Data: "base64ImageData",
|
||||||
|
MimeType: "image/png",
|
||||||
|
},
|
||||||
|
AudioContent{
|
||||||
|
Data: "base64AudioData",
|
||||||
|
MimeType: "audio/wav",
|
||||||
|
},
|
||||||
|
"This is a simple string that should be handled as unknown type",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the function
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
|
||||||
|
// Validate results
|
||||||
|
require.Len(t, result, 4, "Should return the same number of contents")
|
||||||
|
|
||||||
|
// Validate first content (TextContent)
|
||||||
|
typed, ok := result[0].(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||||
|
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||||
|
assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||||
|
|
||||||
|
// Validate second content (ImageContent)
|
||||||
|
typedImg, ok := result[1].(typedImageContent)
|
||||||
|
require.True(t, ok, "Should be typedImageContent")
|
||||||
|
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||||
|
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||||
|
assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate third content (AudioContent)
|
||||||
|
typedAudio, ok := result[2].(typedAudioContent)
|
||||||
|
require.True(t, ok, "Should be typedAudioContent")
|
||||||
|
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||||
|
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||||
|
assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate fourth content (unknown type converted to TextContent)
|
||||||
|
typedUnknown, ok := result[3].(typedTextContent)
|
||||||
|
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test empty input
|
||||||
|
t.Run("EmptyInput", func(t *testing.T) {
|
||||||
|
contents := []any{}
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with nil annotations
|
||||||
|
t.Run("NilAnnotations", func(t *testing.T) {
|
||||||
|
contents := []any{
|
||||||
|
TextContent{
|
||||||
|
Text: "Text with nil annotations",
|
||||||
|
Annotations: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
require.Len(t, result, 1, "Should return one content")
|
||||||
|
|
||||||
|
typed, ok := result[0].(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||||
|
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with custom struct (should be handled as unknown type)
|
||||||
|
t.Run("CustomStruct", func(t *testing.T) {
|
||||||
|
type CustomContent struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := []any{
|
||||||
|
CustomContent{
|
||||||
|
Data: "custom data",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
require.Len(t, result, 1, "Should return one content")
|
||||||
|
|
||||||
|
typed, ok := result[0].(typedTextContent)
|
||||||
|
require.True(t, ok, "Custom struct should be converted to typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||||
|
assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type")
|
||||||
|
})
|
||||||
|
}
|
||||||
149
mcp/vars.go
Normal file
149
mcp/vars.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/syncx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Protocol constants
|
||||||
|
const (
|
||||||
|
// JSON-RPC version as defined in the specification
|
||||||
|
jsonRpcVersion = "2.0"
|
||||||
|
|
||||||
|
// Session identifier key used in request URLs
|
||||||
|
sessionIdKey = "session_id"
|
||||||
|
|
||||||
|
// progressTokenKey is used to track progress of long-running tasks
|
||||||
|
progressTokenKey = "progressToken"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server-Sent Events (SSE) event types
|
||||||
|
const (
|
||||||
|
// Standard message event for JSON-RPC responses
|
||||||
|
eventMessage = "message"
|
||||||
|
|
||||||
|
// Endpoint event for sending endpoint URL to clients
|
||||||
|
eventEndpoint = "endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Content type identifiers
|
||||||
|
const (
|
||||||
|
// ContentTypeObject is object content type
|
||||||
|
ContentTypeObject = "object"
|
||||||
|
|
||||||
|
// ContentTypeText is text content type
|
||||||
|
ContentTypeText = "text"
|
||||||
|
|
||||||
|
// ContentTypeImage is image content type
|
||||||
|
ContentTypeImage = "image"
|
||||||
|
|
||||||
|
// ContentTypeAudio is audio content type
|
||||||
|
ContentTypeAudio = "audio"
|
||||||
|
|
||||||
|
// ContentTypeResource is resource content type
|
||||||
|
ContentTypeResource = "resource"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Collection keys for broadcast events
|
||||||
|
const (
|
||||||
|
// Key for prompts collection
|
||||||
|
keyPrompts = "prompts"
|
||||||
|
|
||||||
|
// Key for resources collection
|
||||||
|
keyResources = "resources"
|
||||||
|
|
||||||
|
// Key for tools collection
|
||||||
|
keyTools = "tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JSON-RPC error codes
|
||||||
|
// Standard error codes from JSON-RPC 2.0 spec
|
||||||
|
const (
|
||||||
|
// Invalid JSON was received by the server
|
||||||
|
errCodeInvalidRequest = -32600
|
||||||
|
|
||||||
|
// The method does not exist / is not available
|
||||||
|
errCodeMethodNotFound = -32601
|
||||||
|
|
||||||
|
// Invalid method parameter(s)
|
||||||
|
errCodeInvalidParams = -32602
|
||||||
|
|
||||||
|
// Internal JSON-RPC error
|
||||||
|
errCodeInternalError = -32603
|
||||||
|
|
||||||
|
// Tool execution timed out
|
||||||
|
errCodeTimeout = -32001
|
||||||
|
|
||||||
|
// Resource not found error
|
||||||
|
errCodeResourceNotFound = -32002
|
||||||
|
|
||||||
|
// Client hasn't completed initialization
|
||||||
|
errCodeClientNotInitialized = -32800
|
||||||
|
)
|
||||||
|
|
||||||
|
// User and assistant role definitions
|
||||||
|
const (
|
||||||
|
// RoleUser is the "user" role - the entity asking questions
|
||||||
|
RoleUser RoleType = "user"
|
||||||
|
|
||||||
|
// RoleAssistant is the "assistant" role - the entity providing responses
|
||||||
|
RoleAssistant RoleType = "assistant"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Method names as defined in the MCP specification
|
||||||
|
const (
|
||||||
|
// Initialize the connection between client and server
|
||||||
|
methodInitialize = "initialize"
|
||||||
|
|
||||||
|
// List available tools
|
||||||
|
methodToolsList = "tools/list"
|
||||||
|
|
||||||
|
// Call a specific tool
|
||||||
|
methodToolsCall = "tools/call"
|
||||||
|
|
||||||
|
// List available prompts
|
||||||
|
methodPromptsList = "prompts/list"
|
||||||
|
|
||||||
|
// Get a specific prompt
|
||||||
|
methodPromptsGet = "prompts/get"
|
||||||
|
|
||||||
|
// List available resources
|
||||||
|
methodResourcesList = "resources/list"
|
||||||
|
|
||||||
|
// Read a specific resource
|
||||||
|
methodResourcesRead = "resources/read"
|
||||||
|
|
||||||
|
// Subscribe to resource updates
|
||||||
|
methodResourcesSubscribe = "resources/subscribe"
|
||||||
|
|
||||||
|
// Simple ping to check server availability
|
||||||
|
methodPing = "ping"
|
||||||
|
|
||||||
|
// Notification that client is fully initialized
|
||||||
|
methodNotificationsInitialized = "notifications/initialized"
|
||||||
|
|
||||||
|
// Notification that a request was canceled
|
||||||
|
methodNotificationsCancelled = "notifications/cancelled"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Event names for Server-Sent Events (SSE)
|
||||||
|
const (
|
||||||
|
// Notification of tool list changes
|
||||||
|
eventToolsListChanged = "tools/list_changed"
|
||||||
|
|
||||||
|
// Notification of prompt list changes
|
||||||
|
eventPromptsListChanged = "prompts/list_changed"
|
||||||
|
|
||||||
|
// Notification of resource list changes
|
||||||
|
eventResourcesListChanged = "resources/list_changed"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Default channel size for events
|
||||||
|
eventChanSize = 10
|
||||||
|
|
||||||
|
// Default ping interval for checking connection availability
|
||||||
|
// use syncx.ForAtomicDuration to ensure atomicity in test race
|
||||||
|
pingInterval = syncx.ForAtomicDuration(30 * time.Second)
|
||||||
|
)
|
||||||
210
mcp/vars_test.go
Normal file
210
mcp/vars_test.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
// filepath: /Users/kevin/Develop/go/opensource/go-zero/mcp/vars_test.go
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestErrorCodes ensures error codes are applied correctly in error responses
|
||||||
|
func TestErrorCodes(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
code int
|
||||||
|
message string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "invalid request error",
|
||||||
|
code: errCodeInvalidRequest,
|
||||||
|
message: "Invalid request",
|
||||||
|
expected: `"code":-32600`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "method not found error",
|
||||||
|
code: errCodeMethodNotFound,
|
||||||
|
message: "Method not found",
|
||||||
|
expected: `"code":-32601`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid params error",
|
||||||
|
code: errCodeInvalidParams,
|
||||||
|
message: "Invalid parameters",
|
||||||
|
expected: `"code":-32602`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "internal error",
|
||||||
|
code: errCodeInternalError,
|
||||||
|
message: "Internal server error",
|
||||||
|
expected: `"code":-32603`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "timeout error",
|
||||||
|
code: errCodeTimeout,
|
||||||
|
message: "Operation timed out",
|
||||||
|
expected: `"code":-32001`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "resource not found error",
|
||||||
|
code: errCodeResourceNotFound,
|
||||||
|
message: "Resource not found",
|
||||||
|
expected: `"code":-32002`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "client not initialized error",
|
||||||
|
code: errCodeClientNotInitialized,
|
||||||
|
message: "Client not initialized",
|
||||||
|
expected: `"code":-32800`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
resp := Response{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: int64(1),
|
||||||
|
Error: &errorObj{
|
||||||
|
Code: tc.code,
|
||||||
|
Message: tc.message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), tc.expected, "Error code should match expected value")
|
||||||
|
assert.Contains(t, string(data), tc.message, "Error message should be included")
|
||||||
|
assert.Contains(t, string(data), jsonRpcVersion, "JSON-RPC version should be included")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestJsonRpcVersion ensures the correct JSON-RPC version is used
|
||||||
|
func TestJsonRpcVersion(t *testing.T) {
|
||||||
|
assert.Equal(t, "2.0", jsonRpcVersion, "JSON-RPC version should be 2.0")
|
||||||
|
|
||||||
|
// Test that it's used in responses
|
||||||
|
resp := Response{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: int64(1),
|
||||||
|
Result: "test",
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"jsonrpc":"2.0"`, "Response should use correct JSON-RPC version")
|
||||||
|
|
||||||
|
// Test that it's expected in requests
|
||||||
|
reqStr := `{"jsonrpc":"2.0","id":1,"method":"test"}`
|
||||||
|
var req Request
|
||||||
|
err = json.Unmarshal([]byte(reqStr), &req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, jsonRpcVersion, req.JsonRpc, "Request should parse correct JSON-RPC version")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSessionIdKey ensures session ID extraction works correctly
|
||||||
|
func TestSessionIdKey(t *testing.T) {
|
||||||
|
// Create a mock server implementation
|
||||||
|
mock := newMockMcpServer(t)
|
||||||
|
defer mock.shutdown()
|
||||||
|
|
||||||
|
// Verify the key constant
|
||||||
|
assert.Equal(t, "session_id", sessionIdKey, "Session ID key should be 'session_id'")
|
||||||
|
|
||||||
|
// Test that session ID is extracted correctly
|
||||||
|
mockR := httptest.NewRequest("GET", "/?"+sessionIdKey+"=test-session", nil)
|
||||||
|
|
||||||
|
// Since the mock server is using the same session key logic,
|
||||||
|
// we can test this by accessing the request query parameters directly
|
||||||
|
sessionID := mockR.URL.Query().Get(sessionIdKey)
|
||||||
|
assert.Equal(t, "test-session", sessionID, "Session ID should be extracted correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEventTypes ensures event types are set correctly in SSE responses
|
||||||
|
func TestEventTypes(t *testing.T) {
|
||||||
|
// Test message event
|
||||||
|
assert.Equal(t, "message", eventMessage, "Message event should be 'message'")
|
||||||
|
|
||||||
|
// Test endpoint event
|
||||||
|
assert.Equal(t, "endpoint", eventEndpoint, "Endpoint event should be 'endpoint'")
|
||||||
|
|
||||||
|
// Verify them in an actual SSE format string
|
||||||
|
messageEvent := "event: " + eventMessage + "\ndata: test\n\n"
|
||||||
|
assert.Contains(t, messageEvent, "event: message", "Message event should format correctly")
|
||||||
|
|
||||||
|
endpointEvent := "event: " + eventEndpoint + "\ndata: test\n\n"
|
||||||
|
assert.Contains(t, endpointEvent, "event: endpoint", "Endpoint event should format correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCollectionKeys checks that collection keys are used correctly
|
||||||
|
func TestCollectionKeys(t *testing.T) {
|
||||||
|
// Verify collection key constants
|
||||||
|
assert.Equal(t, "prompts", keyPrompts, "Prompts key should be 'prompts'")
|
||||||
|
assert.Equal(t, "resources", keyResources, "Resources key should be 'resources'")
|
||||||
|
assert.Equal(t, "tools", keyTools, "Tools key should be 'tools'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRoleTypes checks that role types are used correctly
|
||||||
|
func TestRoleTypes(t *testing.T) {
|
||||||
|
// Test in annotations
|
||||||
|
annotations := Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(annotations)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"audience":["user","assistant"]`, "Role types should marshal correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMethodNames checks that method names are used correctly
|
||||||
|
func TestMethodNames(t *testing.T) {
|
||||||
|
// Verify method name constants
|
||||||
|
methods := map[string]string{
|
||||||
|
"initialize": methodInitialize,
|
||||||
|
"tools/list": methodToolsList,
|
||||||
|
"tools/call": methodToolsCall,
|
||||||
|
"prompts/list": methodPromptsList,
|
||||||
|
"prompts/get": methodPromptsGet,
|
||||||
|
"resources/list": methodResourcesList,
|
||||||
|
"resources/read": methodResourcesRead,
|
||||||
|
"resources/subscribe": methodResourcesSubscribe,
|
||||||
|
"ping": methodPing,
|
||||||
|
"notifications/initialized": methodNotificationsInitialized,
|
||||||
|
"notifications/cancelled": methodNotificationsCancelled,
|
||||||
|
}
|
||||||
|
|
||||||
|
for expected, actual := range methods {
|
||||||
|
assert.Equal(t, expected, actual, "Method name should be "+expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test in a request
|
||||||
|
for methodName := range methods {
|
||||||
|
req := Request{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: int64(1),
|
||||||
|
Method: methodName,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"method":"`+methodName+`"`, "Method name should be used in requests")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEventNames checks that event names are used correctly
|
||||||
|
func TestEventNames(t *testing.T) {
|
||||||
|
// Verify event name constants
|
||||||
|
events := map[string]string{
|
||||||
|
"tools/list_changed": eventToolsListChanged,
|
||||||
|
"prompts/list_changed": eventPromptsListChanged,
|
||||||
|
"resources/list_changed": eventResourcesListChanged,
|
||||||
|
}
|
||||||
|
|
||||||
|
for expected, actual := range events {
|
||||||
|
assert.Equal(t, expected, actual, "Event name should be "+expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test event names in SSE format
|
||||||
|
for _, eventName := range events {
|
||||||
|
sseEvent := "event: " + eventName + "\ndata: test\n\n"
|
||||||
|
assert.Contains(t, sseEvent, "event: "+eventName, "Event name should format correctly in SSE")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
[English](readme.md) | 简体中文
|
[English](readme.md) | 简体中文
|
||||||
|
|
||||||
[](https://github.com/zeromicro/go-zero/actions)
|
|
||||||
[](https://goreportcard.com/report/github.com/zeromicro/go-zero)
|
[](https://goreportcard.com/report/github.com/zeromicro/go-zero)
|
||||||
[](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg)
|
[](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg)
|
||||||
[](https://codecov.io/gh/zeromicro/go-zero)
|
[](https://codecov.io/gh/zeromicro/go-zero)
|
||||||
@@ -302,6 +301,9 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
|
|||||||
>103. 爱芯元智半导体股份有限公司
|
>103. 爱芯元智半导体股份有限公司
|
||||||
>104. 杭州升恒科技有限公司
|
>104. 杭州升恒科技有限公司
|
||||||
>105. 昆仑万维科技股份有限公司
|
>105. 昆仑万维科技股份有限公司
|
||||||
|
>106. 无锡盛算信息技术有限公司
|
||||||
|
>107. 深圳市聚货通信息科技有限公司
|
||||||
|
>108. 浙江银盾云科技有限公司
|
||||||
|
|
||||||
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ go-zero is a web and rpc framework with lots of builtin engineering practices. I
|
|||||||
|
|
||||||
<div align=center>
|
<div align=center>
|
||||||
|
|
||||||
[](https://github.com/zeromicro/go-zero/actions)
|
|
||||||
[](https://codecov.io/gh/zeromicro/go-zero)
|
[](https://codecov.io/gh/zeromicro/go-zero)
|
||||||
[](https://goreportcard.com/report/github.com/zeromicro/go-zero)
|
[](https://goreportcard.com/report/github.com/zeromicro/go-zero)
|
||||||
[](https://github.com/zeromicro/go-zero)
|
[](https://github.com/zeromicro/go-zero)
|
||||||
@@ -251,7 +250,3 @@ go-zero enlisted in the [CNCF Cloud Native Landscape](https://landscape.cncf.io/
|
|||||||
## Give a Star! ⭐
|
## Give a Star! ⭐
|
||||||
|
|
||||||
If you like this project or are using it to learn or start your own solution, give it a star to get updates on new releases. Your support matters!
|
If you like this project or are using it to learn or start your own solution, give it a star to get updates on new releases. Your support matters!
|
||||||
|
|
||||||
## Buy me a coffee
|
|
||||||
|
|
||||||
<a href="https://www.buymeacoffee.com/kevwan" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/rest/handler"
|
"github.com/zeromicro/go-zero/rest/handler"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"github.com/zeromicro/go-zero/rest/internal"
|
"github.com/zeromicro/go-zero/rest/internal"
|
||||||
|
"github.com/zeromicro/go-zero/rest/internal/header"
|
||||||
"github.com/zeromicro/go-zero/rest/internal/response"
|
"github.com/zeromicro/go-zero/rest/internal/response"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,7 +28,10 @@ var ErrSignatureConfig = errors.New("bad config for Signature")
|
|||||||
type engine struct {
|
type engine struct {
|
||||||
conf RestConf
|
conf RestConf
|
||||||
routes []featuredRoutes
|
routes []featuredRoutes
|
||||||
// timeout is the max timeout of all routes
|
// timeout is the max timeout of all routes,
|
||||||
|
// and is used to set http.Server.ReadTimeout and http.Server.WriteTimeout.
|
||||||
|
// this network timeout is used to avoid DoS attacks by sending data slowly
|
||||||
|
// or receiving data slowly with many connections to exhaust server resources.
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
unauthorizedCallback handler.UnauthorizedCallback
|
unauthorizedCallback handler.UnauthorizedCallback
|
||||||
unsignedCallback handler.UnsignedCallback
|
unsignedCallback handler.UnsignedCallback
|
||||||
@@ -54,13 +58,26 @@ func newEngine(c RestConf) *engine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) addRoutes(r featuredRoutes) {
|
func (ng *engine) addRoutes(r featuredRoutes) {
|
||||||
|
if r.sse {
|
||||||
|
r.routes = buildSSERoutes(r.routes)
|
||||||
|
}
|
||||||
ng.routes = append(ng.routes, r)
|
ng.routes = append(ng.routes, r)
|
||||||
|
|
||||||
// need to guarantee the timeout is the max of all routes
|
ng.mightUpdateTimeout(r)
|
||||||
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
|
}
|
||||||
if r.timeout > ng.timeout {
|
|
||||||
ng.timeout = r.timeout
|
func buildSSERoutes(routes []Route) []Route {
|
||||||
|
for i, route := range routes {
|
||||||
|
h := route.Handler
|
||||||
|
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
|
||||||
|
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
|
||||||
|
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
|
||||||
|
h(w, r)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
||||||
@@ -174,11 +191,12 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 {
|
|||||||
return ng.conf.MaxBytes
|
return ng.conf.MaxBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
|
func (ng *engine) checkedTimeout(timeout *time.Duration) time.Duration {
|
||||||
if timeout > 0 {
|
if timeout != nil {
|
||||||
return timeout
|
return *timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if timeout not set in featured routes, use global timeout
|
||||||
return time.Duration(ng.conf.Timeout) * time.Millisecond
|
return time.Duration(ng.conf.Timeout) * time.Millisecond
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,6 +228,32 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
|
|||||||
return ng.shedder
|
return ng.shedder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ng *engine) hasTimeout() bool {
|
||||||
|
return ng.conf.Middlewares.Timeout && ng.timeout > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// mightUpdateTimeout checks if the route timeout is greater than the current,
|
||||||
|
// and updates the engine's timeout accordingly.
|
||||||
|
func (ng *engine) mightUpdateTimeout(r featuredRoutes) {
|
||||||
|
// if global timeout is set to 0, it means no need to set read/write timeout
|
||||||
|
// if route timeout is nil, no need to update ng.timeout
|
||||||
|
if ng.timeout == 0 || r.timeout == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if route timeout is 0 (means no timeout), cannot set read/write timeout
|
||||||
|
if *r.timeout == 0 {
|
||||||
|
ng.timeout = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// need to guarantee the timeout is the max of all routes
|
||||||
|
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
|
||||||
|
if *r.timeout > ng.timeout {
|
||||||
|
ng.timeout = *r.timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// notFoundHandler returns a middleware that handles 404 not found requests.
|
// notFoundHandler returns a middleware that handles 404 not found requests.
|
||||||
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -311,7 +355,7 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// make sure user defined options overwrite default options
|
// make sure user defined options overwrite default options
|
||||||
opts = append([]StartOption{ng.withTimeout()}, opts...)
|
opts = append([]StartOption{ng.withNetworkTimeout()}, opts...)
|
||||||
|
|
||||||
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
|
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
|
||||||
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
|
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
|
||||||
@@ -334,18 +378,19 @@ func (ng *engine) use(middleware Middleware) {
|
|||||||
ng.middlewares = append(ng.middlewares, middleware)
|
ng.middlewares = append(ng.middlewares, middleware)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) withTimeout() internal.StartOption {
|
func (ng *engine) withNetworkTimeout() internal.StartOption {
|
||||||
return func(svr *http.Server) {
|
return func(svr *http.Server) {
|
||||||
timeout := ng.timeout
|
if !ng.hasTimeout() {
|
||||||
if timeout > 0 {
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
||||||
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
||||||
// which triggers the circuit breaker.
|
// which triggers the circuit breaker.
|
||||||
svr.ReadTimeout = 4 * timeout / 5
|
svr.ReadTimeout = 4 * ng.timeout / 5
|
||||||
// factor 1.1, to avoid servers don't have enough time to write responses.
|
// factor 1.1, to avoid servers don't have enough time to write responses.
|
||||||
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
||||||
svr.WriteTimeout = 11 * timeout / 10
|
svr.WriteTimeout = 11 * ng.timeout / 10
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -73,7 +73,17 @@ Verbose: true
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
}},
|
}},
|
||||||
timeout: time.Minute,
|
timeout: ptrOfDuration(time.Minute),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
jwt: jwtSetting{},
|
||||||
|
signature: signatureSetting{},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
timeout: ptrOfDuration(0),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
priority: true,
|
priority: true,
|
||||||
@@ -84,7 +94,7 @@ Verbose: true
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
}},
|
}},
|
||||||
timeout: time.Second,
|
timeout: ptrOfDuration(time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
priority: true,
|
priority: true,
|
||||||
@@ -227,8 +237,12 @@ Verbose: true
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
timeout := time.Second * 3
|
timeout := time.Second * 3
|
||||||
if route.timeout > timeout {
|
if route.timeout != nil {
|
||||||
timeout = route.timeout
|
if *route.timeout == 0 {
|
||||||
|
timeout = 0
|
||||||
|
} else if *route.timeout > timeout {
|
||||||
|
timeout = *route.timeout
|
||||||
|
}
|
||||||
}
|
}
|
||||||
assert.Equal(t, timeout, ng.timeout)
|
assert.Equal(t, timeout, ng.timeout)
|
||||||
})
|
})
|
||||||
@@ -236,10 +250,69 @@ Verbose: true
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewEngine_unsignedCallback(t *testing.T) {
|
||||||
|
priKeyfile, err := fs.TempFilenameWithText(priKey)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
defer os.Remove(priKeyfile)
|
||||||
|
|
||||||
|
yaml := `Name: foo
|
||||||
|
Host: localhost
|
||||||
|
Port: 0
|
||||||
|
Middlewares:
|
||||||
|
Log: false
|
||||||
|
`
|
||||||
|
route := featuredRoutes{
|
||||||
|
priority: true,
|
||||||
|
jwt: jwtSetting{
|
||||||
|
enabled: true,
|
||||||
|
},
|
||||||
|
signature: signatureSetting{
|
||||||
|
enabled: true,
|
||||||
|
SignatureConf: SignatureConf{
|
||||||
|
Strict: true,
|
||||||
|
PrivateKeys: []PrivateKeyConf{
|
||||||
|
{
|
||||||
|
Fingerprint: "a",
|
||||||
|
KeyFile: priKeyfile,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
var index int32
|
||||||
|
t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
|
||||||
|
var cnf RestConf
|
||||||
|
assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
|
||||||
|
ng := newEngine(cnf)
|
||||||
|
if atomic.AddInt32(&index, 1)%2 == 0 {
|
||||||
|
ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request,
|
||||||
|
next http.Handler, strict bool, code int) {
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ng.addRoutes(route)
|
||||||
|
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
|
||||||
|
}))
|
||||||
|
|
||||||
|
assert.Equal(t, time.Duration(time.Second*3), ng.timeout)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestEngine_checkedTimeout(t *testing.T) {
|
func TestEngine_checkedTimeout(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
timeout time.Duration
|
timeout *time.Duration
|
||||||
expect time.Duration
|
expect time.Duration
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -248,17 +321,17 @@ func TestEngine_checkedTimeout(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "less",
|
name: "less",
|
||||||
timeout: time.Millisecond * 500,
|
timeout: ptrOfDuration(time.Millisecond * 500),
|
||||||
expect: time.Millisecond * 500,
|
expect: time.Millisecond * 500,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "equal",
|
name: "equal",
|
||||||
timeout: time.Second,
|
timeout: ptrOfDuration(time.Second),
|
||||||
expect: time.Second,
|
expect: time.Second,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "more",
|
name: "more",
|
||||||
timeout: time.Millisecond * 1500,
|
timeout: ptrOfDuration(time.Millisecond * 1500),
|
||||||
expect: time.Millisecond * 1500,
|
expect: time.Millisecond * 1500,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -394,9 +467,14 @@ func TestEngine_withTimeout(t *testing.T) {
|
|||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
test := test
|
test := test
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
ng := newEngine(RestConf{Timeout: test.timeout})
|
ng := newEngine(RestConf{
|
||||||
|
Timeout: test.timeout,
|
||||||
|
Middlewares: MiddlewaresConf{
|
||||||
|
Timeout: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
svr := &http.Server{}
|
svr := &http.Server{}
|
||||||
ng.withTimeout()(svr)
|
ng.withNetworkTimeout()(svr)
|
||||||
|
|
||||||
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
||||||
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
||||||
@@ -406,6 +484,62 @@ func TestEngine_withTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEngine_ReadWriteTimeout(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
timeout int64
|
||||||
|
middleware bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "0/false",
|
||||||
|
timeout: 0,
|
||||||
|
middleware: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "0/true",
|
||||||
|
timeout: 0,
|
||||||
|
middleware: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set/false",
|
||||||
|
timeout: 1000,
|
||||||
|
middleware: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both set",
|
||||||
|
timeout: 1000,
|
||||||
|
middleware: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
test := test
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
ng := newEngine(RestConf{
|
||||||
|
Timeout: test.timeout,
|
||||||
|
Middlewares: MiddlewaresConf{
|
||||||
|
Timeout: test.middleware,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
svr := &http.Server{}
|
||||||
|
ng.withNetworkTimeout()(svr)
|
||||||
|
|
||||||
|
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
||||||
|
assert.Equal(t, time.Duration(0), svr.IdleTimeout)
|
||||||
|
|
||||||
|
if test.timeout > 0 && test.middleware {
|
||||||
|
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
||||||
|
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, time.Duration(0), svr.ReadTimeout)
|
||||||
|
assert.Equal(t, time.Duration(0), svr.WriteTimeout)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestEngine_start(t *testing.T) {
|
func TestEngine_start(t *testing.T) {
|
||||||
logx.Disable()
|
logx.Disable()
|
||||||
|
|
||||||
|
|||||||
@@ -106,8 +106,8 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
tw.mu.Lock()
|
tw.mu.Lock()
|
||||||
defer tw.mu.Unlock()
|
defer tw.mu.Unlock()
|
||||||
// there isn't any user-defined middleware before TimoutHandler,
|
// there isn't any user-defined middleware before TimeoutHandler,
|
||||||
// so we can guarantee that cancelation in biz related code won't come here.
|
// so we can guarantee that cancellation in biz related code won't come here.
|
||||||
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
|
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
w.WriteHeader(statusClientClosedRequest)
|
w.WriteHeader(statusClientClosedRequest)
|
||||||
@@ -151,7 +151,7 @@ func (tw *timeoutWriter) Flush() {
|
|||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Header returns the underline temporary http.Header.
|
// Header returns the underlying temporary http.Header.
|
||||||
func (tw *timeoutWriter) Header() http.Header {
|
func (tw *timeoutWriter) Header() http.Header {
|
||||||
return tw.h
|
return tw.h
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ func buildRequest(ctx context.Context, method, url string, data any) (*http.Requ
|
|||||||
req.URL.RawQuery = buildFormQuery(u, val[formKey])
|
req.URL.RawQuery = buildFormQuery(u, val[formKey])
|
||||||
fillHeader(req, val[headerKey])
|
fillHeader(req, val[headerKey])
|
||||||
if hasJsonBody {
|
if hasJsonBody {
|
||||||
req.Header.Set(header.ContentType, header.JsonContentType)
|
req.Header.Set(header.ContentType, header.ContentTypeJson)
|
||||||
}
|
}
|
||||||
|
|
||||||
return req, nil
|
return req, nil
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func TestDoRequest_NotFound(t *testing.T) {
|
|||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
|
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
req.Header.Set(header.ContentType, header.JsonContentType)
|
req.Header.Set(header.ContentType, header.ContentTypeJson)
|
||||||
resp, err := DoRequest(req)
|
resp, err := DoRequest(req)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func TestParse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("foo", "bar")
|
w.Header().Set("foo", "bar")
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
w.Write([]byte(`{"name":"kevin","value":100}`))
|
w.Write([]byte(`{"name":"kevin","value":100}`))
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
@@ -38,7 +38,7 @@ func TestParseHeaderError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("foo", "bar")
|
w.Header().Set("foo", "bar")
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||||
@@ -54,7 +54,7 @@ func TestParseNoBody(t *testing.T) {
|
|||||||
}
|
}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("foo", "bar")
|
w.Header().Set("foo", "bar")
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||||
@@ -72,7 +72,7 @@ func TestParseWithZeroValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("foo", "0")
|
w.Header().Set("foo", "0")
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
w.Write([]byte(`{"bar":0}`))
|
w.Write([]byte(`{"bar":0}`))
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
@@ -90,7 +90,7 @@ func TestParseWithNegativeContentLength(t *testing.T) {
|
|||||||
Bar int `json:"bar"`
|
Bar int `json:"bar"`
|
||||||
}
|
}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
w.Write([]byte(`{"bar":0}`))
|
w.Write([]byte(`{"bar":0}`))
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
@@ -124,7 +124,7 @@ func TestParseWithNegativeContentLength(t *testing.T) {
|
|||||||
func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
|
func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
|
||||||
var val struct{}
|
var val struct{}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||||
@@ -156,7 +156,7 @@ func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
|
|||||||
func TestParseJsonBody_BodyError(t *testing.T) {
|
func TestParseJsonBody_BodyError(t *testing.T) {
|
||||||
var val struct{}
|
var val struct{}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func TestNamedService_DoRequestPost(t *testing.T) {
|
|||||||
service := NewService("foo")
|
service := NewService("foo")
|
||||||
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
|
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
req.Header.Set(header.ContentType, header.JsonContentType)
|
req.Header.Set(header.ContentType, header.ContentTypeJson)
|
||||||
resp, err := service.DoRequest(req)
|
resp, err := service.DoRequest(req)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||||
|
|||||||
@@ -476,7 +476,7 @@ func TestParseJsonBody(t *testing.T) {
|
|||||||
|
|
||||||
body := `{"name":"kevin", "age": 18}`
|
body := `{"name":"kevin", "age": 18}`
|
||||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
if assert.NoError(t, Parse(r, &v)) {
|
if assert.NoError(t, Parse(r, &v)) {
|
||||||
assert.Equal(t, "kevin", v.Name)
|
assert.Equal(t, "kevin", v.Name)
|
||||||
@@ -492,7 +492,7 @@ func TestParseJsonBody(t *testing.T) {
|
|||||||
|
|
||||||
body := `{"name":"kevin", "ag": 18}`
|
body := `{"name":"kevin", "ag": 18}`
|
||||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
assert.Error(t, Parse(r, &v))
|
assert.Error(t, Parse(r, &v))
|
||||||
})
|
})
|
||||||
@@ -517,7 +517,7 @@ func TestParseJsonBody(t *testing.T) {
|
|||||||
|
|
||||||
body := `[{"name":"kevin", "age": 18}]`
|
body := `[{"name":"kevin", "age": 18}]`
|
||||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
assert.NoError(t, Parse(r, &v))
|
assert.NoError(t, Parse(r, &v))
|
||||||
assert.Equal(t, 1, len(v))
|
assert.Equal(t, 1, len(v))
|
||||||
@@ -537,7 +537,7 @@ func TestParseJsonBody(t *testing.T) {
|
|||||||
|
|
||||||
body := `[{"name":"apple", "age": 18}]`
|
body := `[{"name":"apple", "age": 18}]`
|
||||||
r := httptest.NewRequest(http.MethodPost, "/a?product=tree", strings.NewReader(body))
|
r := httptest.NewRequest(http.MethodPost, "/a?product=tree", strings.NewReader(body))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
assert.NoError(t, Parse(r, &v))
|
assert.NoError(t, Parse(r, &v))
|
||||||
assert.Equal(t, 1, len(v))
|
assert.Equal(t, 1, len(v))
|
||||||
@@ -555,7 +555,7 @@ func TestParseJsonBody(t *testing.T) {
|
|||||||
body, _ := json.Marshal(v1)
|
body, _ := json.Marshal(v1)
|
||||||
t.Logf("body:%s", string(body))
|
t.Logf("body:%s", string(body))
|
||||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(body)))
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(body)))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||||
var v2 v
|
var v2 v
|
||||||
err := ParseJsonBody(r, &v2)
|
err := ParseJsonBody(r, &v2)
|
||||||
if assert.NoError(t, err) {
|
if assert.NoError(t, err) {
|
||||||
@@ -609,7 +609,7 @@ func TestParseHeaders(t *testing.T) {
|
|||||||
request.Header.Add("addrs", "addr2")
|
request.Header.Add("addrs", "addr2")
|
||||||
request.Header.Add("X-Forwarded-For", "10.0.10.11")
|
request.Header.Add("X-Forwarded-For", "10.0.10.11")
|
||||||
request.Header.Add("x-real-ip", "10.0.11.10")
|
request.Header.Add("x-real-ip", "10.0.11.10")
|
||||||
request.Header.Add("Accept", header.JsonContentType)
|
request.Header.Add("Accept", header.ContentTypeJson)
|
||||||
err = ParseHeaders(request, &v)
|
err = ParseHeaders(request, &v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -619,7 +619,7 @@ func TestParseHeaders(t *testing.T) {
|
|||||||
assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs)
|
assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs)
|
||||||
assert.Equal(t, "10.0.10.11", v.XForwardedFor)
|
assert.Equal(t, "10.0.10.11", v.XForwardedFor)
|
||||||
assert.Equal(t, "10.0.11.10", v.XRealIP)
|
assert.Equal(t, "10.0.11.10", v.XRealIP)
|
||||||
assert.Equal(t, header.JsonContentType, v.Accept)
|
assert.Equal(t, header.ContentTypeJson, v.Accept)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseHeaders_Error(t *testing.T) {
|
func TestParseHeaders_Error(t *testing.T) {
|
||||||
@@ -711,7 +711,7 @@ func TestParseWithFloatPtr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
body := `{"weightFloat32": 3.2}`
|
body := `{"weightFloat32": 3.2}`
|
||||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
if assert.NoError(t, Parse(r, &v)) {
|
if assert.NoError(t, Parse(r, &v)) {
|
||||||
assert.Equal(t, float32(3.2), *v.WeightFloat32)
|
assert.Equal(t, float32(3.2), *v.WeightFloat32)
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ func doWriteJson(w http.ResponseWriter, code int, v any) error {
|
|||||||
return fmt.Errorf("marshal json failed, error: %w", err)
|
return fmt.Errorf("marshal json failed, error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set(ContentType, header.JsonContentType)
|
w.Header().Set(ContentType, header.ContentTypeJson)
|
||||||
w.WriteHeader(code)
|
w.WriteHeader(code)
|
||||||
|
|
||||||
if n, err := w.Write(bs); err != nil {
|
if n, err := w.Write(bs); err != nil {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ const (
|
|||||||
// ContentType means Content-Type.
|
// ContentType means Content-Type.
|
||||||
ContentType = header.ContentType
|
ContentType = header.ContentType
|
||||||
// JsonContentType means application/json.
|
// JsonContentType means application/json.
|
||||||
JsonContentType = header.JsonContentType
|
JsonContentType = header.ContentTypeJson
|
||||||
// KeyField means key.
|
// KeyField means key.
|
||||||
KeyField = "key"
|
KeyField = "key"
|
||||||
// SecretField means secret.
|
// SecretField means secret.
|
||||||
|
|||||||
@@ -2,15 +2,16 @@ package fileserver
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Middleware returns a middleware that serves files from the given file system.
|
// Middleware returns a middleware that serves files from the given file system.
|
||||||
func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
|
func Middleware(upath string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
|
||||||
fileServer := http.FileServer(fs)
|
fileServer := http.FileServer(fs)
|
||||||
pathWithoutTrailSlash := ensureNoTrailingSlash(path)
|
pathWithoutTrailSlash := ensureNoTrailingSlash(upath)
|
||||||
canServe := createServeChecker(path, fs)
|
canServe := createServeChecker(upath, fs)
|
||||||
|
|
||||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -28,9 +29,22 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
|||||||
var lock sync.RWMutex
|
var lock sync.RWMutex
|
||||||
fileChecker := make(map[string]bool)
|
fileChecker := make(map[string]bool)
|
||||||
|
|
||||||
return func(path string) bool {
|
return func(upath string) bool {
|
||||||
|
// Emulate http.Dir.Open’s path normalization for embed.FS.Open.
|
||||||
|
// http.FileServer redirects any request ending in "/index.html"
|
||||||
|
// to the same path without the final "index.html".
|
||||||
|
// So the path here may be empty or end with a "/".
|
||||||
|
// http.Dir.Open uses this logic to clean the path,
|
||||||
|
// correctly handling those two cases.
|
||||||
|
// embed.FS doesn’t perform this normalization, so we apply the same logic here.
|
||||||
|
upath = path.Clean("/" + upath)[1:]
|
||||||
|
if len(upath) == 0 {
|
||||||
|
// if the path is empty, we use "." to open the current directory
|
||||||
|
upath = "."
|
||||||
|
}
|
||||||
|
|
||||||
lock.RLock()
|
lock.RLock()
|
||||||
exist, ok := fileChecker[path]
|
exist, ok := fileChecker[upath]
|
||||||
lock.RUnlock()
|
lock.RUnlock()
|
||||||
if ok {
|
if ok {
|
||||||
return exist
|
return exist
|
||||||
@@ -39,9 +53,9 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
|||||||
lock.Lock()
|
lock.Lock()
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
|
|
||||||
file, err := fs.Open(path)
|
file, err := fs.Open(upath)
|
||||||
exist = err == nil
|
exist = err == nil
|
||||||
fileChecker[path] = exist
|
fileChecker[upath] = exist
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -51,8 +65,8 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) bool {
|
func createServeChecker(upath string, fs http.FileSystem) func(r *http.Request) bool {
|
||||||
pathWithTrailSlash := ensureTrailingSlash(path)
|
pathWithTrailSlash := ensureTrailingSlash(upath)
|
||||||
fileChecker := createFileChecker(fs)
|
fileChecker := createFileChecker(fs)
|
||||||
|
|
||||||
return func(r *http.Request) bool {
|
return func(r *http.Request) bool {
|
||||||
@@ -62,18 +76,18 @@ func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) b
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureTrailingSlash(path string) string {
|
func ensureTrailingSlash(upath string) string {
|
||||||
if strings.HasSuffix(path, "/") {
|
if strings.HasSuffix(upath, "/") {
|
||||||
return path
|
return upath
|
||||||
}
|
}
|
||||||
|
|
||||||
return path + "/"
|
return upath + "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureNoTrailingSlash(path string) string {
|
func ensureNoTrailingSlash(upath string) string {
|
||||||
if strings.HasSuffix(path, "/") {
|
if strings.HasSuffix(upath, "/") {
|
||||||
return path[:len(path)-1]
|
return upath[:len(upath)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
return path
|
return upath
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package fileserver
|
package fileserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"embed"
|
||||||
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -61,6 +63,46 @@ func TestMiddleware(t *testing.T) {
|
|||||||
requestPath: "/ws",
|
requestPath: "/ws",
|
||||||
expectedStatus: http.StatusAlreadyReported,
|
expectedStatus: http.StatusAlreadyReported,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// http.FileServer redirects any request ending in "/index.html"
|
||||||
|
// to the same path, without the final "index.html".
|
||||||
|
{
|
||||||
|
name: "Serve index.html",
|
||||||
|
path: "/static",
|
||||||
|
dir: "testdata",
|
||||||
|
requestPath: "/static/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Serve index.html with path with trailing slash",
|
||||||
|
path: "/static/",
|
||||||
|
dir: "testdata",
|
||||||
|
requestPath: "/static/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Serve index.html in a nested directory",
|
||||||
|
path: "/static",
|
||||||
|
dir: "testdata",
|
||||||
|
requestPath: "/static/nested/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Request index.html indirectly",
|
||||||
|
path: "/static",
|
||||||
|
dir: "testdata",
|
||||||
|
requestPath: "/static/",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Request index.html in a nested directory indirectly",
|
||||||
|
path: "/static",
|
||||||
|
dir: "testdata",
|
||||||
|
requestPath: "/static/nested/",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "hello",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -87,6 +129,128 @@ func TestMiddleware(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed testdata
|
||||||
|
testdataFS embed.FS
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMiddleware_embedFS(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
requestPath string
|
||||||
|
expectedStatus int
|
||||||
|
expectedContent string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Serve static file",
|
||||||
|
path: "/static",
|
||||||
|
requestPath: "/static/example.txt",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Path with trailing slash",
|
||||||
|
path: "/static/",
|
||||||
|
requestPath: "/static/example.txt",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Root path",
|
||||||
|
path: "/",
|
||||||
|
requestPath: "/example.txt",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Pass through non-matching path",
|
||||||
|
path: "/static/",
|
||||||
|
requestPath: "/other/path",
|
||||||
|
expectedStatus: http.StatusAlreadyReported,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not exist file",
|
||||||
|
path: "/assets",
|
||||||
|
requestPath: "/assets/not-exist.txt",
|
||||||
|
expectedStatus: http.StatusAlreadyReported,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not exist file in root",
|
||||||
|
path: "/",
|
||||||
|
requestPath: "/not-exist.txt",
|
||||||
|
expectedStatus: http.StatusAlreadyReported,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "websocket request",
|
||||||
|
path: "/",
|
||||||
|
requestPath: "/ws",
|
||||||
|
expectedStatus: http.StatusAlreadyReported,
|
||||||
|
},
|
||||||
|
|
||||||
|
// http.FileServer redirects any request ending in "/index.html"
|
||||||
|
// to the same path, without the final "index.html".
|
||||||
|
{
|
||||||
|
name: "Serve index.html",
|
||||||
|
path: "/static",
|
||||||
|
requestPath: "/static/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Serve index.html with path with trailing slash",
|
||||||
|
path: "/static/",
|
||||||
|
requestPath: "/static/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Serve index.html in a nested directory",
|
||||||
|
path: "/static",
|
||||||
|
requestPath: "/static/nested/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Request index.html indirectly",
|
||||||
|
path: "/static",
|
||||||
|
requestPath: "/static/",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Request index.html in a nested directory indirectly",
|
||||||
|
path: "/static",
|
||||||
|
requestPath: "/static/nested/",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "hello",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
subFS, err := fs.Sub(testdataFS, "testdata")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
middleware := Middleware(tt.path, http.FS(subFS))
|
||||||
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusAlreadyReported)
|
||||||
|
})
|
||||||
|
|
||||||
|
handlerToTest := middleware(nextHandler)
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handlerToTest.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedStatus, rr.Code)
|
||||||
|
if len(tt.expectedContent) > 0 {
|
||||||
|
assert.Equal(t, tt.expectedContent, rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestEnsureTrailingSlash(t *testing.T) {
|
func TestEnsureTrailingSlash(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
|
|||||||
1
rest/internal/fileserver/testdata/index.html
vendored
Normal file
1
rest/internal/fileserver/testdata/index.html
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
hello
|
||||||
1
rest/internal/fileserver/testdata/nested/index.html
vendored
Normal file
1
rest/internal/fileserver/testdata/nested/index.html
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
hello
|
||||||
@@ -3,8 +3,18 @@ package header
|
|||||||
const (
|
const (
|
||||||
// ApplicationJson stands for application/json.
|
// ApplicationJson stands for application/json.
|
||||||
ApplicationJson = "application/json"
|
ApplicationJson = "application/json"
|
||||||
|
// CacheControl is the header key for Cache-Control.
|
||||||
|
CacheControl = "Cache-Control"
|
||||||
|
// CacheControlNoCache is the value for Cache-Control: no-cache.
|
||||||
|
CacheControlNoCache = "no-cache"
|
||||||
|
// Connection is the header key for Connection.
|
||||||
|
Connection = "Connection"
|
||||||
|
// ConnectionKeepAlive is the value for Connection: keep-alive.
|
||||||
|
ConnectionKeepAlive = "keep-alive"
|
||||||
// ContentType is the header key for Content-Type.
|
// ContentType is the header key for Content-Type.
|
||||||
ContentType = "Content-Type"
|
ContentType = "Content-Type"
|
||||||
// JsonContentType is the content type for JSON.
|
// ContentTypeJson is the content type for JSON.
|
||||||
JsonContentType = "application/json; charset=utf-8"
|
ContentTypeJson = "application/json; charset=utf-8"
|
||||||
|
// ContentTypeEventStream is the content type for event stream.
|
||||||
|
ContentTypeEventStream = "text/event-stream"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -628,7 +628,7 @@ func TestParseWrappedRequest(t *testing.T) {
|
|||||||
func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
|
func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
|
||||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", bytes.NewReader(nil))
|
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", bytes.NewReader(nil))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Request struct {
|
Request struct {
|
||||||
@@ -661,7 +661,7 @@ func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
|
|||||||
func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) {
|
func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) {
|
||||||
r, err := http.NewRequest(http.MethodHead, "http://hello.com/kevin/2017", bytes.NewReader(nil))
|
r, err := http.NewRequest(http.MethodHead, "http://hello.com/kevin/2017", bytes.NewReader(nil))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Request struct {
|
Request struct {
|
||||||
@@ -758,7 +758,7 @@ func TestParseWithAllUtf8(t *testing.T) {
|
|||||||
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
|
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
|
||||||
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
|
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
router := NewRouter()
|
router := NewRouter()
|
||||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||||
@@ -948,7 +948,7 @@ func TestParseWithMissingAllPaths(t *testing.T) {
|
|||||||
func TestParseGetWithContentLengthHeader(t *testing.T) {
|
func TestParseGetWithContentLengthHeader(t *testing.T) {
|
||||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
|
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||||
r.Header.Set(contentLength, "1024")
|
r.Header.Set(contentLength, "1024")
|
||||||
|
|
||||||
router := NewRouter()
|
router := NewRouter()
|
||||||
@@ -976,7 +976,7 @@ func TestParseJsonPostWithTypeMismatch(t *testing.T) {
|
|||||||
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
|
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
|
||||||
bytes.NewBufferString(`{"time": "20170912"}`))
|
bytes.NewBufferString(`{"time": "20170912"}`))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
router := NewRouter()
|
router := NewRouter()
|
||||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||||
@@ -1002,7 +1002,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) {
|
|||||||
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017",
|
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017",
|
||||||
bytes.NewBufferString(`{"time": 20170912}`))
|
bytes.NewBufferString(`{"time": 20170912}`))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
router := NewRouter()
|
router := NewRouter()
|
||||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||||
|
|||||||
@@ -63,6 +63,11 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
|
|||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddRoute adds given route into the Server.
|
||||||
|
func (s *Server) AddRoute(r Route, opts ...RouteOption) {
|
||||||
|
s.AddRoutes([]Route{r}, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
// AddRoutes add given routes into the Server.
|
// AddRoutes add given routes into the Server.
|
||||||
func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
|
func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
|
||||||
r := featuredRoutes{
|
r := featuredRoutes{
|
||||||
@@ -74,11 +79,6 @@ func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
|
|||||||
s.ngin.addRoutes(r)
|
s.ngin.addRoutes(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddRoute adds given route into the Server.
|
|
||||||
func (s *Server) AddRoute(r Route, opts ...RouteOption) {
|
|
||||||
s.AddRoutes([]Route{r}, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PrintRoutes prints the added routes to stdout.
|
// PrintRoutes prints the added routes to stdout.
|
||||||
func (s *Server) PrintRoutes() {
|
func (s *Server) PrintRoutes() {
|
||||||
s.ngin.print()
|
s.ngin.print()
|
||||||
@@ -119,6 +119,16 @@ func (s *Server) Use(middleware Middleware) {
|
|||||||
s.ngin.use(middleware)
|
s.ngin.use(middleware)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// build builds the Server and binds the routes to the router.
|
||||||
|
func (s *Server) build() error {
|
||||||
|
return s.ngin.bindRoutes(s.router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// serve serves the HTTP requests using the Server's router.
|
||||||
|
func (s *Server) serve(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.router.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
// ToMiddleware converts the given handler to a Middleware.
|
// ToMiddleware converts the given handler to a Middleware.
|
||||||
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
||||||
return func(handle http.HandlerFunc) http.HandlerFunc {
|
return func(handle http.HandlerFunc) http.HandlerFunc {
|
||||||
@@ -279,10 +289,18 @@ func WithSignature(signature SignatureConf) RouteOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithSSE returns a RouteOption to enable server-sent events.
|
||||||
|
func WithSSE() RouteOption {
|
||||||
|
return func(r *featuredRoutes) {
|
||||||
|
r.sse = true
|
||||||
|
r.timeout = ptrOfDuration(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithTimeout returns a RouteOption to set timeout with given value.
|
// WithTimeout returns a RouteOption to set timeout with given value.
|
||||||
func WithTimeout(timeout time.Duration) RouteOption {
|
func WithTimeout(timeout time.Duration) RouteOption {
|
||||||
return func(r *featuredRoutes) {
|
return func(r *featuredRoutes) {
|
||||||
r.timeout = timeout
|
r.timeout = &timeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -317,6 +335,10 @@ func handleError(err error) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ptrOfDuration(d time.Duration) *time.Duration {
|
||||||
|
return &d
|
||||||
|
}
|
||||||
|
|
||||||
func validateSecret(secret string) {
|
func validateSecret(secret string) {
|
||||||
if len(secret) < 8 {
|
if len(secret) < 8 {
|
||||||
panic("secret's length can't be less than 8")
|
panic("secret's length can't be less than 8")
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/rest/chain"
|
"github.com/zeromicro/go-zero/rest/chain"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"github.com/zeromicro/go-zero/rest/internal/cors"
|
"github.com/zeromicro/go-zero/rest/internal/cors"
|
||||||
|
"github.com/zeromicro/go-zero/rest/internal/header"
|
||||||
"github.com/zeromicro/go-zero/rest/router"
|
"github.com/zeromicro/go-zero/rest/router"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -344,7 +345,7 @@ func TestWithPriority(t *testing.T) {
|
|||||||
func TestWithTimeout(t *testing.T) {
|
func TestWithTimeout(t *testing.T) {
|
||||||
var fr featuredRoutes
|
var fr featuredRoutes
|
||||||
WithTimeout(time.Hour)(&fr)
|
WithTimeout(time.Hour)(&fr)
|
||||||
assert.Equal(t, time.Hour, fr.timeout)
|
assert.Equal(t, time.Hour, *fr.timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWithTLSConfig(t *testing.T) {
|
func TestWithTLSConfig(t *testing.T) {
|
||||||
@@ -754,6 +755,40 @@ Port: 54321
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerEventStream(t *testing.T) {
|
||||||
|
server := MustNewServer(RestConf{})
|
||||||
|
server.AddRoutes([]Route{
|
||||||
|
{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/foo",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("foo"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/bar",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("bar"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, WithSSE())
|
||||||
|
|
||||||
|
check := func(val string) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%s", val), http.NoBody)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
serve(server, rr, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rr.Code)
|
||||||
|
assert.Equal(t, header.ContentTypeEventStream, rr.Header().Get(header.ContentType))
|
||||||
|
assert.Equal(t, header.CacheControlNoCache, rr.Header().Get(header.CacheControl))
|
||||||
|
assert.Equal(t, header.ConnectionKeepAlive, rr.Header().Get(header.Connection))
|
||||||
|
assert.Equal(t, val, rr.Body.String())
|
||||||
|
}
|
||||||
|
check("foo")
|
||||||
|
check("bar")
|
||||||
|
}
|
||||||
|
|
||||||
//go:embed testdata
|
//go:embed testdata
|
||||||
var content embed.FS
|
var content embed.FS
|
||||||
|
|
||||||
@@ -770,7 +805,7 @@ func TestServerEmbedFileSystem(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// serve is for test purpose, allow developer to do a unit test with
|
// serve is for test purpose, allow developer to do a unit test with
|
||||||
// all defined router without starting an HTTP Server.
|
// all defined routes without starting an HTTP Server.
|
||||||
//
|
//
|
||||||
// For example:
|
// For example:
|
||||||
//
|
//
|
||||||
@@ -784,6 +819,6 @@ func TestServerEmbedFileSystem(t *testing.T) {
|
|||||||
// serve(server, w, r)
|
// serve(server, w, r)
|
||||||
// // verify the response
|
// // verify the response
|
||||||
func serve(s *Server, w http.ResponseWriter, r *http.Request) {
|
func serve(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||||
s.ngin.bindRoutes(s.router)
|
_ = s.build()
|
||||||
s.router.ServeHTTP(w, r)
|
s.serve(w, r)
|
||||||
}
|
}
|
||||||
|
|||||||
27
rest/serverless.go
Normal file
27
rest/serverless.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package rest
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// Serverless is a wrapper around Server that allows it to be used in serverless environments.
|
||||||
|
type Serverless struct {
|
||||||
|
server *Server
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServerless creates a new Serverless instance from the provided Server.
|
||||||
|
func NewServerless(server *Server) (*Serverless, error) {
|
||||||
|
// Ensure the server is built before using it in a serverless context.
|
||||||
|
// Why not call server.build() when serving requests,
|
||||||
|
// is because we need to ensure fail fast behavior.
|
||||||
|
if err := server.build(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Serverless{
|
||||||
|
server: server,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve handles HTTP requests by delegating them to the underlying Server instance.
|
||||||
|
func (s *Serverless) Serve(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.server.serve(w, r)
|
||||||
|
}
|
||||||
67
rest/serverless_test.go
Normal file
67
rest/serverless_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewServerless(t *testing.T) {
|
||||||
|
logtest.Discard(t)
|
||||||
|
|
||||||
|
const configYaml = `
|
||||||
|
Name: foo
|
||||||
|
Host: localhost
|
||||||
|
Port: 0
|
||||||
|
`
|
||||||
|
var cnf RestConf
|
||||||
|
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
|
||||||
|
|
||||||
|
svr, err := NewServer(cnf)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
svr.AddRoute(Route{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("Hello World"))
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
serverless, err := NewServerless(svr)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
serverless.Serve(w, r)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "Hello World", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewServerlessWithError(t *testing.T) {
|
||||||
|
logtest.Discard(t)
|
||||||
|
|
||||||
|
const configYaml = `
|
||||||
|
Name: foo
|
||||||
|
Host: localhost
|
||||||
|
Port: 0
|
||||||
|
`
|
||||||
|
var cnf RestConf
|
||||||
|
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
|
||||||
|
|
||||||
|
svr, err := NewServer(cnf)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
svr.AddRoute(Route{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "notstartwith/",
|
||||||
|
Handler: nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err = NewServerless(svr)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
@@ -31,10 +31,11 @@ type (
|
|||||||
}
|
}
|
||||||
|
|
||||||
featuredRoutes struct {
|
featuredRoutes struct {
|
||||||
timeout time.Duration
|
timeout *time.Duration
|
||||||
priority bool
|
priority bool
|
||||||
jwt jwtSetting
|
jwt jwtSetting
|
||||||
signature signatureSetting
|
signature signatureSetting
|
||||||
|
sse bool
|
||||||
routes []Route
|
routes []Route
|
||||||
maxBytes int64
|
maxBytes int64
|
||||||
}
|
}
|
||||||
|
|||||||
1
tools/goctl/.gitignore
vendored
Normal file
1
tools/goctl/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
dist
|
||||||
@@ -2,6 +2,7 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/apigen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/apigen"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/dartgen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/dartgen"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/docgen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/docgen"
|
||||||
@@ -10,6 +11,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/tools/goctl/api/javagen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/javagen"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/ktgen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/ktgen"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/new"
|
"github.com/zeromicro/go-zero/tools/goctl/api/new"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/api/swagger"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/tsgen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/tsgen"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
|
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
@@ -31,6 +33,7 @@ var (
|
|||||||
ktCmd = cobrax.NewCommand("kt", cobrax.WithRunE(ktgen.KtCommand))
|
ktCmd = cobrax.NewCommand("kt", cobrax.WithRunE(ktgen.KtCommand))
|
||||||
pluginCmd = cobrax.NewCommand("plugin", cobrax.WithRunE(plugin.PluginCommand))
|
pluginCmd = cobrax.NewCommand("plugin", cobrax.WithRunE(plugin.PluginCommand))
|
||||||
tsCmd = cobrax.NewCommand("ts", cobrax.WithRunE(tsgen.TsCommand))
|
tsCmd = cobrax.NewCommand("ts", cobrax.WithRunE(tsgen.TsCommand))
|
||||||
|
swaggerCmd = cobrax.NewCommand("swagger", cobrax.WithRunE(swagger.Command))
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -46,6 +49,7 @@ func init() {
|
|||||||
pluginCmdFlags = pluginCmd.Flags()
|
pluginCmdFlags = pluginCmd.Flags()
|
||||||
tsCmdFlags = tsCmd.Flags()
|
tsCmdFlags = tsCmd.Flags()
|
||||||
validateCmdFlags = validateCmd.Flags()
|
validateCmdFlags = validateCmd.Flags()
|
||||||
|
swaggerCmdFlags = swaggerCmd.Flags()
|
||||||
)
|
)
|
||||||
|
|
||||||
apiCmdFlags.StringVar(&apigen.VarStringOutput, "o")
|
apiCmdFlags.StringVar(&apigen.VarStringOutput, "o")
|
||||||
@@ -73,6 +77,7 @@ func init() {
|
|||||||
goCmdFlags.StringVar(&gogen.VarStringRemote, "remote")
|
goCmdFlags.StringVar(&gogen.VarStringRemote, "remote")
|
||||||
goCmdFlags.StringVar(&gogen.VarStringBranch, "branch")
|
goCmdFlags.StringVar(&gogen.VarStringBranch, "branch")
|
||||||
goCmdFlags.BoolVar(&gogen.VarBoolWithTest, "test")
|
goCmdFlags.BoolVar(&gogen.VarBoolWithTest, "test")
|
||||||
|
goCmdFlags.BoolVar(&gogen.VarBoolTypeGroup, "type-group")
|
||||||
goCmdFlags.StringVarWithDefaultValue(&gogen.VarStringStyle, "style", config.DefaultFormat)
|
goCmdFlags.StringVarWithDefaultValue(&gogen.VarStringStyle, "style", config.DefaultFormat)
|
||||||
|
|
||||||
javaCmdFlags.StringVar(&javagen.VarStringDir, "dir")
|
javaCmdFlags.StringVar(&javagen.VarStringDir, "dir")
|
||||||
@@ -97,8 +102,13 @@ func init() {
|
|||||||
tsCmdFlags.StringVar(&tsgen.VarStringCaller, "caller")
|
tsCmdFlags.StringVar(&tsgen.VarStringCaller, "caller")
|
||||||
tsCmdFlags.BoolVar(&tsgen.VarBoolUnWrap, "unwrap")
|
tsCmdFlags.BoolVar(&tsgen.VarBoolUnWrap, "unwrap")
|
||||||
|
|
||||||
|
swaggerCmdFlags.StringVar(&swagger.VarStringAPI, "api")
|
||||||
|
swaggerCmdFlags.StringVar(&swagger.VarStringDir, "dir")
|
||||||
|
swaggerCmdFlags.StringVar(&swagger.VarStringFilename, "filename")
|
||||||
|
swaggerCmdFlags.BoolVar(&swagger.VarBoolYaml, "yaml")
|
||||||
|
|
||||||
validateCmdFlags.StringVar(&validate.VarStringAPI, "api")
|
validateCmdFlags.StringVar(&validate.VarStringAPI, "api")
|
||||||
|
|
||||||
// Add sub-commands
|
// Add sub-commands
|
||||||
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd)
|
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd, swaggerCmd)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,9 +42,20 @@ var (
|
|||||||
func GoFormatApi(_ *cobra.Command, _ []string) error {
|
func GoFormatApi(_ *cobra.Command, _ []string) error {
|
||||||
var be errorx.BatchError
|
var be errorx.BatchError
|
||||||
if VarBoolUseStdin {
|
if VarBoolUseStdin {
|
||||||
|
if env.UseExperimental() {
|
||||||
|
data, err := io.ReadAll(os.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
be.Add(err)
|
||||||
|
} else {
|
||||||
|
if err := apiF.Source(data, os.Stdout); err != nil {
|
||||||
|
be.Add(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if err := apiFormatReader(os.Stdin, VarStringDir, VarBoolSkipCheckDeclare); err != nil {
|
if err := apiFormatReader(os.Stdin, VarStringDir, VarBoolSkipCheckDeclare); err != nil {
|
||||||
be.Add(err)
|
be.Add(err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if len(VarStringDir) == 0 {
|
if len(VarStringDir) == 0 {
|
||||||
return errors.New("missing -dir")
|
return errors.New("missing -dir")
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ var (
|
|||||||
// VarStringStyle describes the style of output files.
|
// VarStringStyle describes the style of output files.
|
||||||
VarStringStyle string
|
VarStringStyle string
|
||||||
VarBoolWithTest bool
|
VarBoolWithTest bool
|
||||||
|
// VarBoolTypeGroup describes whether to group types.
|
||||||
|
VarBoolTypeGroup bool
|
||||||
)
|
)
|
||||||
|
|
||||||
// GoCommand gen go project files from command line
|
// GoCommand gen go project files from command line
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/collection"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
@@ -39,20 +41,152 @@ func BuildTypes(types []spec.Type) (string, error) {
|
|||||||
return builder.String(), nil
|
return builder.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
func getTypeName(tp spec.Type) string {
|
||||||
val, err := BuildTypes(api.Types)
|
if tp == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch val := tp.(type) {
|
||||||
|
case spec.DefineStruct:
|
||||||
|
typeName := util.Title(tp.Name())
|
||||||
|
return typeName
|
||||||
|
case spec.PointerType:
|
||||||
|
return getTypeName(val.Type)
|
||||||
|
case spec.ArrayType:
|
||||||
|
return getTypeName(val.Value)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func genTypesWithGroup(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
|
groupTypes := make(map[string]map[string]spec.Type)
|
||||||
|
typesBelongToFiles := make(map[string]*collection.Set)
|
||||||
|
|
||||||
|
for _, v := range api.Service.Groups {
|
||||||
|
group := v.GetAnnotation(groupProperty)
|
||||||
|
if len(group) == 0 {
|
||||||
|
group = groupTypeDefault
|
||||||
|
}
|
||||||
|
// convert filepath to Identifier name spec.
|
||||||
|
group = strings.TrimPrefix(group, "/")
|
||||||
|
group = strings.TrimSuffix(group, "/")
|
||||||
|
group = util.SafeString(group)
|
||||||
|
for _, v := range v.Routes {
|
||||||
|
requestTypeName := getTypeName(v.RequestType)
|
||||||
|
responseTypeName := getTypeName(v.ResponseType)
|
||||||
|
requestTypeFileSet, ok := typesBelongToFiles[requestTypeName]
|
||||||
|
if !ok {
|
||||||
|
requestTypeFileSet = collection.NewSet()
|
||||||
|
}
|
||||||
|
if len(requestTypeName) > 0 {
|
||||||
|
requestTypeFileSet.AddStr(group)
|
||||||
|
typesBelongToFiles[requestTypeName] = requestTypeFileSet
|
||||||
|
}
|
||||||
|
|
||||||
|
responseTypeFileSet, ok := typesBelongToFiles[responseTypeName]
|
||||||
|
if !ok {
|
||||||
|
responseTypeFileSet = collection.NewSet()
|
||||||
|
}
|
||||||
|
if len(responseTypeName) > 0 {
|
||||||
|
responseTypeFileSet.AddStr(group)
|
||||||
|
typesBelongToFiles[responseTypeName] = responseTypeFileSet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typesInOneFile := make(map[string]*collection.Set)
|
||||||
|
for typeName, fileSet := range typesBelongToFiles {
|
||||||
|
count := fileSet.Count()
|
||||||
|
switch {
|
||||||
|
case count == 0: // it means there has no structure type or no request/response body
|
||||||
|
continue
|
||||||
|
case count == 1: // it means a structure type used in only one group.
|
||||||
|
groupName := fileSet.KeysStr()[0]
|
||||||
|
typeSet, ok := typesInOneFile[groupName]
|
||||||
|
if !ok {
|
||||||
|
typeSet = collection.NewSet()
|
||||||
|
}
|
||||||
|
typeSet.AddStr(typeName)
|
||||||
|
typesInOneFile[groupName] = typeSet
|
||||||
|
default: // it means this type is used in multiple groups.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range api.Types {
|
||||||
|
typeName := util.Title(v.Name())
|
||||||
|
groupSet, ok := typesBelongToFiles[typeName]
|
||||||
|
var typeCount int
|
||||||
|
if !ok {
|
||||||
|
typeCount = 0
|
||||||
|
} else {
|
||||||
|
typeCount = groupSet.Count()
|
||||||
|
}
|
||||||
|
|
||||||
|
if typeCount == 0 { // not belong to any group
|
||||||
|
types, ok := groupTypes[groupTypeDefault]
|
||||||
|
if !ok {
|
||||||
|
types = make(map[string]spec.Type)
|
||||||
|
}
|
||||||
|
types[typeName] = v
|
||||||
|
groupTypes[groupTypeDefault] = types
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if typeCount == 1 { // belong to one group
|
||||||
|
groupName := groupSet.KeysStr()[0]
|
||||||
|
types, ok := groupTypes[groupName]
|
||||||
|
if !ok {
|
||||||
|
types = make(map[string]spec.Type)
|
||||||
|
}
|
||||||
|
types[typeName] = v
|
||||||
|
groupTypes[groupName] = types
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// belong to multiple groups
|
||||||
|
types, ok := groupTypes[groupTypeDefault]
|
||||||
|
if !ok {
|
||||||
|
types = make(map[string]spec.Type)
|
||||||
|
}
|
||||||
|
types[typeName] = v
|
||||||
|
groupTypes[groupTypeDefault] = types
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
for group, typeGroup := range groupTypes {
|
||||||
|
var types []spec.Type
|
||||||
|
for _, v := range typeGroup {
|
||||||
|
types = append(types, v)
|
||||||
|
}
|
||||||
|
sort.Slice(types, func(i, j int) bool {
|
||||||
|
return types[i].Name() < types[j].Name()
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := writeTypes(dir, group, cfg, types); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeTypes(dir, baseFilename string, cfg *config.Config, types []spec.Type) error {
|
||||||
|
if len(types) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
val, err := BuildTypes(types)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
typeFilename, err := format.FileNamingFormat(cfg.NamingFormat, typesFile)
|
typeFilename, err := format.FileNamingFormat(cfg.NamingFormat, baseFilename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
typeFilename = typeFilename + ".go"
|
typeFilename = typeFilename + ".go"
|
||||||
filename := path.Join(dir, typesDir, typeFilename)
|
filename := path.Join(dir, typesDir, typeFilename)
|
||||||
os.Remove(filename)
|
_ = os.Remove(filename)
|
||||||
|
|
||||||
return genFile(fileGenConfig{
|
return genFile(fileGenConfig{
|
||||||
dir: dir,
|
dir: dir,
|
||||||
@@ -70,6 +204,13 @@ func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
|
if VarBoolTypeGroup {
|
||||||
|
return genTypesWithGroup(dir, cfg, api)
|
||||||
|
}
|
||||||
|
return writeTypes(dir, typesFile, cfg, api.Types)
|
||||||
|
}
|
||||||
|
|
||||||
func writeType(writer io.Writer, tp spec.Type) error {
|
func writeType(writer io.Writer, tp spec.Type) error {
|
||||||
structType, ok := tp.(spec.DefineStruct)
|
structType, ok := tp.(spec.DefineStruct)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
@@ -10,4 +10,6 @@ const (
|
|||||||
middlewareDir = internal + "middleware"
|
middlewareDir = internal + "middleware"
|
||||||
typesDir = internal + typesPacket
|
typesDir = internal + typesPacket
|
||||||
groupProperty = "group"
|
groupProperty = "group"
|
||||||
|
|
||||||
|
groupTypeDefault="types"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"path"
|
"path"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
@@ -96,13 +96,13 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
|
|||||||
for _, item := range c.responseTypes {
|
for _, item := range c.responseTypes {
|
||||||
if item.Name() == defineStruct.Name() {
|
if item.Name() == defineStruct.Name() {
|
||||||
superClassName = "HttpResponseData"
|
superClassName = "HttpResponseData"
|
||||||
if !stringx.Contains(c.imports, httpResponseData) {
|
if !slices.Contains(c.imports, httpResponseData) {
|
||||||
c.imports = append(c.imports, httpResponseData)
|
c.imports = append(c.imports, httpResponseData)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if superClassName == "HttpData" && !stringx.Contains(c.imports, httpData) {
|
if superClassName == "HttpData" && !slices.Contains(c.imports, httpData) {
|
||||||
c.imports = append(c.imports, httpData)
|
c.imports = append(c.imports, httpData)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,7 +266,7 @@ func (c *componentsContext) genGetSet(writer io.Writer, indent int) error {
|
|||||||
tyString := javaType
|
tyString := javaType
|
||||||
decorator := ""
|
decorator := ""
|
||||||
javaPrimitiveType := []string{"int", "long", "boolean", "float", "double", "short"}
|
javaPrimitiveType := []string{"int", "long", "boolean", "float", "double", "short"}
|
||||||
if !stringx.Contains(javaPrimitiveType, javaType) {
|
if !slices.Contains(javaPrimitiveType, javaType) {
|
||||||
if member.IsOptional() || member.IsOmitEmpty() {
|
if member.IsOptional() || member.IsOmitEmpty() {
|
||||||
decorator = "@Nullable "
|
decorator = "@Nullable "
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ package spec
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"path"
|
"path"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -57,14 +57,14 @@ func (m Member) Tags() []*Tag {
|
|||||||
|
|
||||||
// IsOptional returns true if tag is optional
|
// IsOptional returns true if tag is optional
|
||||||
func (m Member) IsOptional() bool {
|
func (m Member) IsOptional() bool {
|
||||||
if !m.IsBodyMember() {
|
if !m.IsBodyMember() && !m.IsFormMember() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
tag := m.Tags()
|
tag := m.Tags()
|
||||||
for _, item := range tag {
|
for _, item := range tag {
|
||||||
if item.Key == bodyTagKey {
|
if item.Key == bodyTagKey || item.Key == formTagKey {
|
||||||
if stringx.Contains(item.Options, "optional") {
|
if slices.Contains(item.Options, "optional") {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -81,7 +81,7 @@ func (m Member) IsOmitEmpty() bool {
|
|||||||
tag := m.Tags()
|
tag := m.Tags()
|
||||||
for _, item := range tag {
|
for _, item := range tag {
|
||||||
if item.Key == bodyTagKey {
|
if item.Key == bodyTagKey {
|
||||||
if stringx.Contains(item.Options, "omitempty") {
|
if slices.Contains(item.Options, "omitempty") {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -93,7 +93,7 @@ func (m Member) IsOmitEmpty() bool {
|
|||||||
func (m Member) GetPropertyName() (string, error) {
|
func (m Member) GetPropertyName() (string, error) {
|
||||||
tags := m.Tags()
|
tags := m.Tags()
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
if stringx.Contains(definedKeys, tag.Key) {
|
if slices.Contains(definedKeys, tag.Key) {
|
||||||
if tag.Name == "-" {
|
if tag.Name == "-" {
|
||||||
return util.Untitle(m.Name), nil
|
return util.Untitle(m.Name), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ type (
|
|||||||
|
|
||||||
// ApiSpec describes an api file
|
// ApiSpec describes an api file
|
||||||
ApiSpec struct {
|
ApiSpec struct {
|
||||||
Info Info // Deprecated: useless expression
|
Info Info
|
||||||
Syntax ApiSyntax // Deprecated: useless expression
|
Syntax ApiSyntax // Deprecated: useless expression
|
||||||
Imports []Import // Deprecated: useless expression
|
Imports []Import // Deprecated: useless expression
|
||||||
Types []Type
|
Types []Type
|
||||||
|
|||||||
@@ -49,6 +49,9 @@ func Parse(tag string) (*Tags, error) {
|
|||||||
|
|
||||||
// Get gets tag value by specified key
|
// Get gets tag value by specified key
|
||||||
func (t *Tags) Get(key string) (*Tag, error) {
|
func (t *Tags) Get(key string) (*Tag, error) {
|
||||||
|
if t == nil {
|
||||||
|
return nil, errTagNotExist
|
||||||
|
}
|
||||||
for _, tag := range t.tags {
|
for _, tag := range t.tags {
|
||||||
if tag.Key == key {
|
if tag.Key == key {
|
||||||
return tag, nil
|
return tag, nil
|
||||||
@@ -60,6 +63,9 @@ func (t *Tags) Get(key string) (*Tag, error) {
|
|||||||
|
|
||||||
// Keys returns all keys in Tags
|
// Keys returns all keys in Tags
|
||||||
func (t *Tags) Keys() []string {
|
func (t *Tags) Keys() []string {
|
||||||
|
if t == nil {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
var keys []string
|
var keys []string
|
||||||
for _, tag := range t.tags {
|
for _, tag := range t.tags {
|
||||||
keys = append(keys, tag.Key)
|
keys = append(keys, tag.Key)
|
||||||
@@ -69,5 +75,8 @@ func (t *Tags) Keys() []string {
|
|||||||
|
|
||||||
// Tags returns all tags in Tags
|
// Tags returns all tags in Tags
|
||||||
func (t *Tags) Tags() []*Tag {
|
func (t *Tags) Tags() []*Tag {
|
||||||
|
if t == nil {
|
||||||
|
return []*Tag{}
|
||||||
|
}
|
||||||
return t.tags
|
return t.tags
|
||||||
}
|
}
|
||||||
|
|||||||
68
tools/goctl/api/spec/tags_test.go
Normal file
68
tools/goctl/api/spec/tags_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package spec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTags_Get(t *testing.T) {
|
||||||
|
tags := &Tags{
|
||||||
|
tags: []*Tag{
|
||||||
|
{Key: "json", Name: "foo", Options: []string{"omitempty"}},
|
||||||
|
{Key: "xml", Name: "bar", Options: nil},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tag, err := tags.Get("json")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, tag)
|
||||||
|
assert.Equal(t, "json", tag.Key)
|
||||||
|
assert.Equal(t, "foo", tag.Name)
|
||||||
|
|
||||||
|
_, err = tags.Get("yaml")
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
var nilTags *Tags
|
||||||
|
_, err = nilTags.Get("json")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTags_Keys(t *testing.T) {
|
||||||
|
tags := &Tags{
|
||||||
|
tags: []*Tag{
|
||||||
|
{Key: "json", Name: "foo", Options: []string{"omitempty"}},
|
||||||
|
{Key: "xml", Name: "bar", Options: nil},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := tags.Keys()
|
||||||
|
expected := []string{"json", "xml"}
|
||||||
|
assert.Equal(t, expected, keys)
|
||||||
|
|
||||||
|
var nilTags *Tags
|
||||||
|
nilKeys := nilTags.Keys()
|
||||||
|
assert.Empty(t, nilKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTags_Tags(t *testing.T) {
|
||||||
|
tags := &Tags{
|
||||||
|
tags: []*Tag{
|
||||||
|
{Key: "json", Name: "foo", Options: []string{"omitempty"}},
|
||||||
|
{Key: "xml", Name: "bar", Options: nil},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := tags.Tags()
|
||||||
|
assert.Len(t, result, 2)
|
||||||
|
assert.Equal(t, "json", result[0].Key)
|
||||||
|
assert.Equal(t, "foo", result[0].Name)
|
||||||
|
assert.Equal(t, []string{"omitempty"}, result[0].Options)
|
||||||
|
assert.Equal(t, "xml", result[1].Key)
|
||||||
|
assert.Equal(t, "bar", result[1].Name)
|
||||||
|
assert.Nil(t, result[1].Options)
|
||||||
|
|
||||||
|
var nilTags *Tags
|
||||||
|
nilResult := nilTags.Tags()
|
||||||
|
assert.Empty(t, nilResult)
|
||||||
|
}
|
||||||
75
tools/goctl/api/swagger/annotation.go
Normal file
75
tools/goctl/api/swagger/annotation.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package swagger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getBoolFromKVOrDefault(properties map[string]string, key string, def bool) bool {
|
||||||
|
if len(properties) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
md := metadata.New(properties)
|
||||||
|
val := md.Get(key)
|
||||||
|
if len(val) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
str := util.Unquote(val[0])
|
||||||
|
if len(str) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
res, _ := strconv.ParseBool(str)
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStringFromKVOrDefault(properties map[string]string, key string, def string) string {
|
||||||
|
if len(properties) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
md := metadata.New(properties)
|
||||||
|
val := md.Get(key)
|
||||||
|
if len(val) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
str := util.Unquote(val[0])
|
||||||
|
if len(str) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
func getListFromInfoOrDefault(properties map[string]string, key string, def []string) []string {
|
||||||
|
if len(properties) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
md := metadata.New(properties)
|
||||||
|
val := md.Get(key)
|
||||||
|
if len(val) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
str := util.Unquote(val[0])
|
||||||
|
if len(str) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
resp := util.FieldsAndTrimSpace(str, commaRune)
|
||||||
|
if len(resp) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFirstUsableString(def ...string) string {
|
||||||
|
if len(def) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, val := range def {
|
||||||
|
str := util.Unquote(val)
|
||||||
|
if len(str) != 0 {
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
53
tools/goctl/api/swagger/annotation_test.go
Normal file
53
tools/goctl/api/swagger/annotation_test.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package swagger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_getBoolFromKVOrDefault(t *testing.T) {
|
||||||
|
properties := map[string]string{
|
||||||
|
"enabled": `"true"`,
|
||||||
|
"disabled": `"false"`,
|
||||||
|
"invalid": `"notabool"`,
|
||||||
|
"empty_value": `""`,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, getBoolFromKVOrDefault(properties, "enabled", false))
|
||||||
|
assert.False(t, getBoolFromKVOrDefault(properties, "disabled", true))
|
||||||
|
assert.False(t, getBoolFromKVOrDefault(properties, "invalid", false))
|
||||||
|
assert.True(t, getBoolFromKVOrDefault(properties, "missing", true))
|
||||||
|
assert.False(t, getBoolFromKVOrDefault(properties, "empty_value", false))
|
||||||
|
assert.False(t, getBoolFromKVOrDefault(nil, "nil", false))
|
||||||
|
assert.False(t, getBoolFromKVOrDefault(map[string]string{}, "empty", false))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_getStringFromKVOrDefault(t *testing.T) {
|
||||||
|
properties := map[string]string{
|
||||||
|
"name": `"example"`,
|
||||||
|
"empty": `""`,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "example", getStringFromKVOrDefault(properties, "name", "default"))
|
||||||
|
assert.Equal(t, "default", getStringFromKVOrDefault(properties, "empty", "default"))
|
||||||
|
assert.Equal(t, "default", getStringFromKVOrDefault(properties, "missing", "default"))
|
||||||
|
assert.Equal(t, "default", getStringFromKVOrDefault(nil, "nil", "default"))
|
||||||
|
assert.Equal(t, "default", getStringFromKVOrDefault(map[string]string{}, "empty", "default"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_getListFromInfoOrDefault(t *testing.T) {
|
||||||
|
properties := map[string]string{
|
||||||
|
"list": `"a, b, c"`,
|
||||||
|
"empty": `""`,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"a", " b", " c"}, getListFromInfoOrDefault(properties, "list", []string{"default"}))
|
||||||
|
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "empty", []string{"default"}))
|
||||||
|
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "missing", []string{"default"}))
|
||||||
|
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(nil, "nil", []string{"default"}))
|
||||||
|
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(map[string]string{}, "empty", []string{"default"}))
|
||||||
|
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(map[string]string{
|
||||||
|
"foo": ",,",
|
||||||
|
}, "foo", []string{"default"}))
|
||||||
|
}
|
||||||
138
tools/goctl/api/swagger/api.go
Normal file
138
tools/goctl/api/swagger/api.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package swagger
|
||||||
|
|
||||||
|
import "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
|
|
||||||
|
func fillAllStructs(api *spec.ApiSpec) {
|
||||||
|
var (
|
||||||
|
tps []spec.Type
|
||||||
|
structTypes = make(map[string]spec.DefineStruct)
|
||||||
|
groups []spec.Group
|
||||||
|
)
|
||||||
|
for _, tp := range api.Types {
|
||||||
|
structTypes[tp.Name()] = tp.(spec.DefineStruct)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tp := range api.Types {
|
||||||
|
filledTP := fillStruct("", tp, structTypes)
|
||||||
|
tps = append(tps, filledTP)
|
||||||
|
structTypes[filledTP.Name()] = filledTP.(spec.DefineStruct)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range api.Service.Groups {
|
||||||
|
var routes []spec.Route
|
||||||
|
for _, route := range group.Routes {
|
||||||
|
route.RequestType = fillStruct("", route.RequestType, structTypes)
|
||||||
|
route.ResponseType = fillStruct("", route.ResponseType, structTypes)
|
||||||
|
routes = append(routes, route)
|
||||||
|
}
|
||||||
|
group.Routes = routes
|
||||||
|
groups = append(groups, group)
|
||||||
|
}
|
||||||
|
api.Service.Groups = groups
|
||||||
|
api.Types = tps
|
||||||
|
}
|
||||||
|
|
||||||
|
func fillStruct(parent string, tp spec.Type, allTypes map[string]spec.DefineStruct) spec.Type {
|
||||||
|
switch val := tp.(type) {
|
||||||
|
case spec.DefineStruct:
|
||||||
|
var members []spec.Member
|
||||||
|
for _, member := range val.Members {
|
||||||
|
switch memberType := member.Type.(type) {
|
||||||
|
case spec.PointerType:
|
||||||
|
member.Type = spec.PointerType{
|
||||||
|
RawName: memberType.RawName,
|
||||||
|
Type: fillStruct(val.Name(), memberType.Type, allTypes),
|
||||||
|
}
|
||||||
|
case spec.ArrayType:
|
||||||
|
member.Type = spec.ArrayType{
|
||||||
|
RawName: memberType.RawName,
|
||||||
|
Value: fillStruct(val.Name(), memberType.Value, allTypes),
|
||||||
|
}
|
||||||
|
case spec.MapType:
|
||||||
|
member.Type = spec.MapType{
|
||||||
|
RawName: memberType.RawName,
|
||||||
|
Key: memberType.Key,
|
||||||
|
Value: fillStruct(val.Name(), memberType.Value, allTypes),
|
||||||
|
}
|
||||||
|
case spec.DefineStruct:
|
||||||
|
if parent != memberType.Name() { // avoid recursive struct
|
||||||
|
if st, ok := allTypes[memberType.Name()]; ok {
|
||||||
|
member.Type = fillStruct("", st, allTypes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case spec.NestedStruct:
|
||||||
|
member.Type = fillStruct("", member.Type, allTypes)
|
||||||
|
}
|
||||||
|
members = append(members, member)
|
||||||
|
}
|
||||||
|
if len(members) == 0 {
|
||||||
|
st, ok := allTypes[val.RawName]
|
||||||
|
if ok {
|
||||||
|
members = st.Members
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val.Members = members
|
||||||
|
return val
|
||||||
|
case spec.NestedStruct:
|
||||||
|
var members []spec.Member
|
||||||
|
for _, member := range val.Members {
|
||||||
|
switch memberType := member.Type.(type) {
|
||||||
|
case spec.PointerType:
|
||||||
|
member.Type = spec.PointerType{
|
||||||
|
RawName: memberType.RawName,
|
||||||
|
Type: fillStruct(val.Name(), memberType.Type, allTypes),
|
||||||
|
}
|
||||||
|
case spec.ArrayType:
|
||||||
|
member.Type = spec.ArrayType{
|
||||||
|
RawName: memberType.RawName,
|
||||||
|
Value: fillStruct(val.Name(), memberType.Value, allTypes),
|
||||||
|
}
|
||||||
|
case spec.MapType:
|
||||||
|
member.Type = spec.MapType{
|
||||||
|
RawName: memberType.RawName,
|
||||||
|
Key: memberType.Key,
|
||||||
|
Value: fillStruct(val.Name(), memberType.Value, allTypes),
|
||||||
|
}
|
||||||
|
case spec.DefineStruct:
|
||||||
|
if parent != memberType.Name() { // avoid recursive struct
|
||||||
|
if st, ok := allTypes[memberType.Name()]; ok {
|
||||||
|
member.Type = fillStruct("", st, allTypes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case spec.NestedStruct:
|
||||||
|
if parent != memberType.Name() {
|
||||||
|
if st, ok := allTypes[memberType.Name()]; ok {
|
||||||
|
member.Type = fillStruct("", st, allTypes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
members = append(members, member)
|
||||||
|
}
|
||||||
|
if len(members) == 0 {
|
||||||
|
st, ok := allTypes[val.RawName]
|
||||||
|
if ok {
|
||||||
|
members = st.Members
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val.Members = members
|
||||||
|
return val
|
||||||
|
case spec.PointerType:
|
||||||
|
return spec.PointerType{
|
||||||
|
RawName: val.RawName,
|
||||||
|
Type: fillStruct(parent, val.Type, allTypes),
|
||||||
|
}
|
||||||
|
case spec.ArrayType:
|
||||||
|
return spec.ArrayType{
|
||||||
|
RawName: val.RawName,
|
||||||
|
Value: fillStruct(parent, val.Value, allTypes),
|
||||||
|
}
|
||||||
|
case spec.MapType:
|
||||||
|
return spec.MapType{
|
||||||
|
RawName: val.RawName,
|
||||||
|
Key: val.Key,
|
||||||
|
Value: fillStruct(parent, val.Value, allTypes),
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return tp
|
||||||
|
}
|
||||||
|
}
|
||||||
87
tools/goctl/api/swagger/command.go
Normal file
87
tools/goctl/api/swagger/command.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package swagger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/parser"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||||
|
"gopkg.in/yaml.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// VarStringAPI specifies the API filename.
|
||||||
|
VarStringAPI string
|
||||||
|
|
||||||
|
// VarStringDir specifies the directory to generate swagger file.
|
||||||
|
VarStringDir string
|
||||||
|
|
||||||
|
// VarStringFilename specifies the generated swagger file name without the extension.
|
||||||
|
VarStringFilename string
|
||||||
|
|
||||||
|
// VarBoolYaml specifies whether to generate a YAML file.
|
||||||
|
VarBoolYaml bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func Command(_ *cobra.Command, _ []string) error {
|
||||||
|
if len(VarStringAPI) == 0 {
|
||||||
|
return errors.New("missing -api")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(VarStringDir) == 0 {
|
||||||
|
return errors.New("missing -dir")
|
||||||
|
}
|
||||||
|
|
||||||
|
api, err := parser.Parse(VarStringAPI, "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fillAllStructs(api)
|
||||||
|
|
||||||
|
if err := api.Validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
swagger, err := spec2Swagger(api)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data, err := json.MarshalIndent(swagger, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pathx.MkdirIfNotExist(VarStringDir)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
filename := VarStringFilename
|
||||||
|
if filename == "" {
|
||||||
|
base := filepath.Base(VarStringAPI)
|
||||||
|
filename = strings.TrimSuffix(base, filepath.Ext(base))
|
||||||
|
}
|
||||||
|
|
||||||
|
if VarBoolYaml {
|
||||||
|
filePath := filepath.Join(VarStringDir, filename+".yaml")
|
||||||
|
|
||||||
|
var jsonObj interface{}
|
||||||
|
if err := yaml.Unmarshal(data, &jsonObj); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := yaml.Marshal(jsonObj)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(filePath, data, 0644)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate json swagger file
|
||||||
|
filePath := filepath.Join(VarStringDir, filename+".json")
|
||||||
|
return os.WriteFile(filePath, data, 0644)
|
||||||
|
}
|
||||||
65
tools/goctl/api/swagger/const.go
Normal file
65
tools/goctl/api/swagger/const.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package swagger
|
||||||
|
|
||||||
|
const (
|
||||||
|
tagHeader = "header"
|
||||||
|
tagPath = "path"
|
||||||
|
tagForm = "form"
|
||||||
|
tagJson = "json"
|
||||||
|
defFlag = "default="
|
||||||
|
enumFlag = "options="
|
||||||
|
rangeFlag = "range="
|
||||||
|
exampleFlag = "example="
|
||||||
|
optionalFlag = "optional"
|
||||||
|
|
||||||
|
paramsInHeader = "header"
|
||||||
|
paramsInPath = "path"
|
||||||
|
paramsInQuery = "query"
|
||||||
|
paramsInBody = "body"
|
||||||
|
paramsInForm = "formData"
|
||||||
|
|
||||||
|
swaggerTypeInteger = "integer"
|
||||||
|
swaggerTypeNumber = "number"
|
||||||
|
swaggerTypeString = "string"
|
||||||
|
swaggerTypeBoolean = "boolean"
|
||||||
|
swaggerTypeArray = "array"
|
||||||
|
swaggerTypeObject = "object"
|
||||||
|
|
||||||
|
swaggerVersion = "2.0"
|
||||||
|
applicationJson = "application/json"
|
||||||
|
applicationForm = "application/x-www-form-urlencoded"
|
||||||
|
schemeHttps = "https"
|
||||||
|
defaultBasePath = "/"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
propertyKeyUseDefinitions = "useDefinitions"
|
||||||
|
propertyKeyExternalDocsDescription = "externalDocsDescription"
|
||||||
|
propertyKeyExternalDocsURL = "externalDocsURL"
|
||||||
|
propertyKeyTitle = "title"
|
||||||
|
propertyKeyTermsOfService = "termsOfService"
|
||||||
|
propertyKeyDescription = "description"
|
||||||
|
propertyKeyVersion = "version"
|
||||||
|
propertyKeyContactName = "contactName"
|
||||||
|
propertyKeyContactURL = "contactURL"
|
||||||
|
propertyKeyContactEmail = "contactEmail"
|
||||||
|
propertyKeyLicenseName = "licenseName"
|
||||||
|
propertyKeyLicenseURL = "licenseURL"
|
||||||
|
propertyKeyProduces = "produces"
|
||||||
|
propertyKeyConsumes = "consumes"
|
||||||
|
propertyKeySchemes = "schemes"
|
||||||
|
propertyKeyTags = "tags"
|
||||||
|
propertyKeySummary = "summary"
|
||||||
|
propertyKeyGroup = "group"
|
||||||
|
propertyKeyOperationId = "operationId"
|
||||||
|
propertyKeyDeprecated = "deprecated"
|
||||||
|
propertyKeyPrefix = "prefix"
|
||||||
|
propertyKeyAuthType = "authType"
|
||||||
|
propertyKeyHost = "host"
|
||||||
|
propertyKeyBasePath = "basePath"
|
||||||
|
propertyKeyWrapCodeMsg = "wrapCodeMsg"
|
||||||
|
propertyKeyBizCodeEnumDescription = "bizCodeEnumDescription"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultValueOfPropertyUseDefinition = false
|
||||||
|
)
|
||||||
25
tools/goctl/api/swagger/contenttype.go
Normal file
25
tools/goctl/api/swagger/contenttype.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package swagger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
func consumesFromTypeOrDef(ctx Context, method string, tp spec.Type) []string {
|
||||||
|
if strings.EqualFold(method, http.MethodGet) {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
if tp == nil {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
structType, ok := tp.(spec.DefineStruct)
|
||||||
|
if !ok {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
if typeContainsTag(ctx, structType, tagJson) {
|
||||||
|
return []string{applicationJson}
|
||||||
|
}
|
||||||
|
return []string{applicationForm}
|
||||||
|
}
|
||||||
68
tools/goctl/api/swagger/contenttype_test.go
Normal file
68
tools/goctl/api/swagger/contenttype_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package swagger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConsumesFromTypeOrDef(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
tp spec.Type
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "GET method with nil type",
|
||||||
|
method: http.MethodGet,
|
||||||
|
tp: nil,
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "post nil",
|
||||||
|
method: http.MethodPost,
|
||||||
|
tp: nil,
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "json tag",
|
||||||
|
method: http.MethodPost,
|
||||||
|
tp: spec.DefineStruct{
|
||||||
|
Members: []spec.Member{
|
||||||
|
{
|
||||||
|
Tag: `json:"example"`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: []string{applicationJson},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "form tag",
|
||||||
|
method: http.MethodPost,
|
||||||
|
tp: spec.DefineStruct{
|
||||||
|
Members: []spec.Member{
|
||||||
|
{
|
||||||
|
Tag: `form:"example"`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: []string{applicationForm},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non struct type",
|
||||||
|
method: http.MethodPost,
|
||||||
|
tp: spec.ArrayType{},
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := consumesFromTypeOrDef(testingContext(t), tt.method, tt.tp)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
28
tools/goctl/api/swagger/context.go
Normal file
28
tools/goctl/api/swagger/context.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package swagger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Context struct {
|
||||||
|
UseDefinitions bool
|
||||||
|
WrapCodeMsg bool
|
||||||
|
BizCodeEnumDescription string
|
||||||
|
}
|
||||||
|
|
||||||
|
func testingContext(_ *testing.T) Context {
|
||||||
|
return Context{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextFromApi(info spec.Info) Context {
|
||||||
|
if len(info.Properties) == 0 {
|
||||||
|
return Context{}
|
||||||
|
}
|
||||||
|
return Context{
|
||||||
|
UseDefinitions: getBoolFromKVOrDefault(info.Properties, propertyKeyUseDefinitions, defaultValueOfPropertyUseDefinition),
|
||||||
|
WrapCodeMsg: getBoolFromKVOrDefault(info.Properties, propertyKeyWrapCodeMsg, false),
|
||||||
|
BizCodeEnumDescription: getStringFromKVOrDefault(info.Properties, propertyKeyBizCodeEnumDescription, "business code"),
|
||||||
|
}
|
||||||
|
}
|
||||||
32
tools/goctl/api/swagger/definition.go
Normal file
32
tools/goctl/api/swagger/definition.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package swagger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-openapi/spec"
|
||||||
|
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
|
)
|
||||||
|
|
||||||
|
func definitionsFromTypes(ctx Context, types []apiSpec.Type) spec.Definitions {
|
||||||
|
if !ctx.UseDefinitions {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
definitions := make(spec.Definitions)
|
||||||
|
for _, tp := range types {
|
||||||
|
typeName := tp.Name()
|
||||||
|
definitions[typeName] = schemaFromType(ctx, tp)
|
||||||
|
}
|
||||||
|
return definitions
|
||||||
|
}
|
||||||
|
|
||||||
|
func schemaFromType(ctx Context, tp apiSpec.Type) spec.Schema {
|
||||||
|
p, r := propertiesFromType(ctx, tp)
|
||||||
|
props := spec.SchemaProps{
|
||||||
|
Type: typeFromGoType(ctx, tp),
|
||||||
|
Properties: p,
|
||||||
|
AdditionalProperties: mapFromGoType(ctx, tp),
|
||||||
|
Items: itemsFromGoType(ctx, tp),
|
||||||
|
Required: r,
|
||||||
|
}
|
||||||
|
return spec.Schema{
|
||||||
|
SchemaProps: props,
|
||||||
|
}
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user