mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-06-11 16:40:19 +08:00
Compare commits
137 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c9ff6a10d3 | ||
|
|
a71e56de52 | ||
|
|
bae8d4f4c8 | ||
|
|
8c6266f338 | ||
|
|
95d5b81f44 | ||
|
|
bca7bbc142 | ||
|
|
df9a52664b | ||
|
|
937cf0db96 | ||
|
|
75cebb65f8 | ||
|
|
410f56e73a | ||
|
|
017909a3ab | ||
|
|
0d31e6c375 | ||
|
|
0ba86b1849 | ||
|
|
4cacc4d9d3 | ||
|
|
a99c14da4a | ||
|
|
985582264a | ||
|
|
8364e341e1 | ||
|
|
0f2b589d4d | ||
|
|
19fec36d24 | ||
|
|
f037bf344d | ||
|
|
d99cf35b07 | ||
|
|
f459f1b5ff | ||
|
|
0140fd417b | ||
|
|
7969e0ca38 | ||
|
|
91c885b5b0 | ||
|
|
d4cccca387 | ||
|
|
4b2095ed03 | ||
|
|
1229eeb2d2 | ||
|
|
9142b146c5 | ||
|
|
8a1b2d5aed | ||
|
|
da5d39e6ca | ||
|
|
68c5a17c67 | ||
|
|
b53f9f5f2d | ||
|
|
36d57626b6 | ||
|
|
4e36ba832f | ||
|
|
a44954a771 | ||
|
|
f3edd4b880 | ||
|
|
2de3e397ff | ||
|
|
a435eb56f2 | ||
|
|
d80761c147 | ||
|
|
e7bd0d8b60 | ||
|
|
b109b3ef4c | ||
|
|
e3c371ac89 | ||
|
|
15eb6f4f6d | ||
|
|
4d3681b71c | ||
|
|
a682bda0bb | ||
|
|
45b27ad93a | ||
|
|
292a8302a1 | ||
|
|
91ab1f6d2b | ||
|
|
5048c350ae | ||
|
|
94edc32f3e | ||
|
|
ec989b2e2a | ||
|
|
82fe802e81 | ||
|
|
072d68f897 | ||
|
|
2e91ba5811 | ||
|
|
5564c43197 | ||
|
|
e55158b0f7 | ||
|
|
69aa7fe346 | ||
|
|
c3820a95c1 | ||
|
|
493f3bad0f | ||
|
|
eb0d5ad3a4 | ||
|
|
14192050ae | ||
|
|
9193e771e3 | ||
|
|
808b4e496a | ||
|
|
e416d01f8d | ||
|
|
789c5de873 | ||
|
|
52078a0c14 | ||
|
|
7ef13116a0 | ||
|
|
6b8053410a | ||
|
|
81c6928445 | ||
|
|
761c2dd716 | ||
|
|
aeceb3cfbe | ||
|
|
15ea07aad1 | ||
|
|
98bebbc74f | ||
|
|
eafd11d949 | ||
|
|
b251ce346e | ||
|
|
812140ba36 | ||
|
|
44735e949c | ||
|
|
bf313c3c56 | ||
|
|
94e7753262 | ||
|
|
9c478626d2 | ||
|
|
801c283478 | ||
|
|
2a54faf997 | ||
|
|
ecd98f3653 | ||
|
|
61641581eb | ||
|
|
6f2730d5ae | ||
|
|
0eff777b62 | ||
|
|
cafbf535f7 | ||
|
|
6edfce63e3 | ||
|
|
cdb0098b18 | ||
|
|
620c7f9693 | ||
|
|
dba444a382 | ||
|
|
b24fb3ebf7 | ||
|
|
967f0926eb | ||
|
|
e68c683df9 | ||
|
|
247985a065 | ||
|
|
80573af0d8 | ||
|
|
c0394b631a | ||
|
|
68d1aba377 | ||
|
|
3315e60272 | ||
|
|
327ef73700 | ||
|
|
eb11521655 | ||
|
|
4c37545e55 | ||
|
|
2f47c1fba4 | ||
|
|
16d54d0ace | ||
|
|
9925bcbf99 | ||
|
|
38a5ecb796 | ||
|
|
af78fc7c5f | ||
|
|
790302b486 | ||
|
|
6a0672b801 | ||
|
|
560c61612c | ||
|
|
6a988dc4a9 | ||
|
|
15842c3c7a | ||
|
|
f2914a74df | ||
|
|
f113d512e8 | ||
|
|
7a4818da59 | ||
|
|
48d0709ca6 | ||
|
|
f747585518 | ||
|
|
507ff96546 | ||
|
|
651eabb4c6 | ||
|
|
e6b4372056 | ||
|
|
24073969a1 | ||
|
|
ca797ed22c | ||
|
|
e347d3f8f8 | ||
|
|
396393b336 | ||
|
|
1f0531b254 | ||
|
|
77fb271a06 | ||
|
|
af7cf79963 | ||
|
|
7926d396d7 | ||
|
|
080cd3df84 | ||
|
|
c4e1a6a2d8 | ||
|
|
4e71e95e44 | ||
|
|
84db9bcd15 | ||
|
|
b28f79ac11 | ||
|
|
e134e77b2b | ||
|
|
f669d84ce8 | ||
|
|
9213b8ac27 |
18
.github/workflows/issue-translator.yml
vendored
18
.github/workflows/issue-translator.yml
vendored
@@ -1,18 +0,0 @@
|
|||||||
name: 'issue-translator'
|
|
||||||
on:
|
|
||||||
issue_comment:
|
|
||||||
types: [created]
|
|
||||||
issues:
|
|
||||||
types: [opened]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: usthe/issues-translate-action@v2.7
|
|
||||||
with:
|
|
||||||
IS_MODIFY_TITLE: true
|
|
||||||
# not require, default false, . Decide whether to modify the issue title
|
|
||||||
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
|
|
||||||
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑🤝🧑👫🧑🏿🤝🧑🏻👩🏾🤝👨🏿👬🏿
|
|
||||||
# not require. Customize the translation robot prefix message.
|
|
||||||
42
.github/workflows/version-check.yml
vendored
Normal file
42
.github/workflows/version-check.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
name: Release Version Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'tools/goctl/v*'
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
version-check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.21'
|
||||||
|
|
||||||
|
- name: Extract tag version
|
||||||
|
id: get_version
|
||||||
|
run: |
|
||||||
|
# Extract version from tools/goctl/v* format
|
||||||
|
VERSION="${GITHUB_REF#refs/tags/tools/goctl/v}"
|
||||||
|
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||||
|
echo "Extracted version: $VERSION"
|
||||||
|
|
||||||
|
- name: Check version in goctl source code
|
||||||
|
run: |
|
||||||
|
# Change to goctl directory
|
||||||
|
cd tools/goctl
|
||||||
|
|
||||||
|
# Check version in BuildVersion constant
|
||||||
|
VERSION_IN_CODE=$(grep -r "const BuildVersion =" . | grep -o '".*"' | tr -d '"')
|
||||||
|
echo "Version in code: $VERSION_IN_CODE"
|
||||||
|
echo "Expected version: $VERSION"
|
||||||
|
|
||||||
|
if [ "$VERSION_IN_CODE" != "$VERSION" ]; then
|
||||||
|
echo "Version mismatch: Version in code ($VERSION_IN_CODE) doesn't match tag version ($VERSION)"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "✅ Version check passed!"
|
||||||
@@ -8,16 +8,12 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/mathx"
|
|
||||||
"github.com/zeromicro/go-zero/core/proc"
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
"github.com/zeromicro/go-zero/core/stringx"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const numHistoryReasons = 5
|
||||||
numHistoryReasons = 5
|
|
||||||
timeFormat = "15:04:05"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrServiceUnavailable is returned when the Breaker state is open.
|
// ErrServiceUnavailable is returned when the Breaker state is open.
|
||||||
var ErrServiceUnavailable = errors.New("circuit breaker is open")
|
var ErrServiceUnavailable = errors.New("circuit breaker is open")
|
||||||
@@ -262,9 +258,9 @@ type errorWindow struct {
|
|||||||
|
|
||||||
func (ew *errorWindow) add(reason string) {
|
func (ew *errorWindow) add(reason string) {
|
||||||
ew.lock.Lock()
|
ew.lock.Lock()
|
||||||
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason)
|
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(time.TimeOnly), reason)
|
||||||
ew.index = (ew.index + 1) % numHistoryReasons
|
ew.index = (ew.index + 1) % numHistoryReasons
|
||||||
ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
|
ew.count = min(ew.count+1, numHistoryReasons)
|
||||||
ew.lock.Unlock()
|
ew.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -423,7 +423,7 @@ func TestRegistry_Monitor(t *testing.T) {
|
|||||||
GetRegistry().clusters = map[string]*cluster{
|
GetRegistry().clusters = map[string]*cluster{
|
||||||
getClusterKey(endpoints): {
|
getClusterKey(endpoints): {
|
||||||
watchers: map[watchKey]*watchValue{
|
watchers: map[watchKey]*watchValue{
|
||||||
watchKey{
|
{
|
||||||
key: "foo",
|
key: "foo",
|
||||||
exactMatch: true,
|
exactMatch: true,
|
||||||
}: {
|
}: {
|
||||||
@@ -449,7 +449,7 @@ func TestRegistry_Unmonitor(t *testing.T) {
|
|||||||
GetRegistry().clusters = map[string]*cluster{
|
GetRegistry().clusters = map[string]*cluster{
|
||||||
getClusterKey(endpoints): {
|
getClusterKey(endpoints): {
|
||||||
watchers: map[watchKey]*watchValue{
|
watchers: map[watchKey]*watchValue{
|
||||||
watchKey{
|
{
|
||||||
key: "foo",
|
key: "foo",
|
||||||
exactMatch: true,
|
exactMatch: true,
|
||||||
}: {
|
}: {
|
||||||
|
|||||||
@@ -86,21 +86,16 @@ func TestConsistentHashIncrementalTransfer(t *testing.T) {
|
|||||||
|
|
||||||
func TestConsistentHashTransferOnFailure(t *testing.T) {
|
func TestConsistentHashTransferOnFailure(t *testing.T) {
|
||||||
index := 41
|
index := 41
|
||||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
ratioNotExists := getTransferRatioOnFailure(t, index)
|
||||||
var transferred int
|
assert.True(t, ratioNotExists == 0, fmt.Sprintf("%d: %f", index, ratioNotExists))
|
||||||
for k, v := range newKeys {
|
index = 13
|
||||||
if v != keys[k] {
|
ratio := getTransferRatioOnFailure(t, index)
|
||||||
transferred++
|
assert.True(t, ratio < 2.5/keySize, fmt.Sprintf("%d: %f", index, ratio))
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ratio := float32(transferred) / float32(requestSize)
|
|
||||||
assert.True(t, ratio < 2.5/float32(keySize), fmt.Sprintf("%d: %f", index, ratio))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConsistentHashLeastTransferOnFailure(t *testing.T) {
|
func TestConsistentHashLeastTransferOnFailure(t *testing.T) {
|
||||||
prefix := "localhost:"
|
prefix := "localhost:"
|
||||||
index := 41
|
index := 13
|
||||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index)
|
keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index)
|
||||||
for k, v := range keys {
|
for k, v := range keys {
|
||||||
newV := newKeys[k]
|
newV := newKeys[k]
|
||||||
@@ -164,6 +159,17 @@ func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[i
|
|||||||
return keys, newKeys
|
return keys, newKeys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTransferRatioOnFailure(t *testing.T, index int) float32 {
|
||||||
|
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
||||||
|
var transferred int
|
||||||
|
for k, v := range newKeys {
|
||||||
|
if v != keys[k] {
|
||||||
|
transferred++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return float32(transferred) / float32(requestSize)
|
||||||
|
}
|
||||||
|
|
||||||
type mockNode struct {
|
type mockNode struct {
|
||||||
addr string
|
addr string
|
||||||
id int
|
id int
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package hash
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"fmt"
|
"encoding/hex"
|
||||||
|
|
||||||
"github.com/spaolacci/murmur3"
|
"github.com/spaolacci/murmur3"
|
||||||
)
|
)
|
||||||
@@ -20,6 +20,7 @@ func Md5(data []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Md5Hex returns the md5 hex string of data.
|
// Md5Hex returns the md5 hex string of data.
|
||||||
|
// This function is optimized for better performance than fmt.Sprintf.
|
||||||
func Md5Hex(data []byte) string {
|
func Md5Hex(data []byte) string {
|
||||||
return fmt.Sprintf("%x", Md5(data))
|
return hex.EncodeToString(Md5(data))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,12 +7,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
fieldsContextKey contextKey
|
|
||||||
globalFields atomic.Value
|
globalFields atomic.Value
|
||||||
globalFieldsLock sync.Mutex
|
globalFieldsLock sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
type contextKey struct{}
|
type fieldsKey struct{}
|
||||||
|
|
||||||
// AddGlobalFields adds global fields.
|
// AddGlobalFields adds global fields.
|
||||||
func AddGlobalFields(fields ...LogField) {
|
func AddGlobalFields(fields ...LogField) {
|
||||||
@@ -29,16 +28,16 @@ func AddGlobalFields(fields ...LogField) {
|
|||||||
|
|
||||||
// ContextWithFields returns a new context with the given fields.
|
// ContextWithFields returns a new context with the given fields.
|
||||||
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
|
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
|
||||||
if val := ctx.Value(fieldsContextKey); val != nil {
|
if val := ctx.Value(fieldsKey{}); val != nil {
|
||||||
if arr, ok := val.([]LogField); ok {
|
if arr, ok := val.([]LogField); ok {
|
||||||
allFields := make([]LogField, 0, len(arr)+len(fields))
|
allFields := make([]LogField, 0, len(arr)+len(fields))
|
||||||
allFields = append(allFields, arr...)
|
allFields = append(allFields, arr...)
|
||||||
allFields = append(allFields, fields...)
|
allFields = append(allFields, fields...)
|
||||||
return context.WithValue(ctx, fieldsContextKey, allFields)
|
return context.WithValue(ctx, fieldsKey{}, allFields)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return context.WithValue(ctx, fieldsContextKey, fields)
|
return context.WithValue(ctx, fieldsKey{}, fields)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithFields returns a new logger with the given fields.
|
// WithFields returns a new logger with the given fields.
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func TestAddGlobalFields(t *testing.T) {
|
|||||||
|
|
||||||
func TestContextWithFields(t *testing.T) {
|
func TestContextWithFields(t *testing.T) {
|
||||||
ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2))
|
ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2))
|
||||||
vals := ctx.Value(fieldsContextKey)
|
vals := ctx.Value(fieldsKey{})
|
||||||
assert.NotNil(t, vals)
|
assert.NotNil(t, vals)
|
||||||
fields, ok := vals.([]LogField)
|
fields, ok := vals.([]LogField)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
@@ -43,7 +43,7 @@ func TestContextWithFields(t *testing.T) {
|
|||||||
|
|
||||||
func TestWithFields(t *testing.T) {
|
func TestWithFields(t *testing.T) {
|
||||||
ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2))
|
ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2))
|
||||||
vals := ctx.Value(fieldsContextKey)
|
vals := ctx.Value(fieldsKey{})
|
||||||
assert.NotNil(t, vals)
|
assert.NotNil(t, vals)
|
||||||
fields, ok := vals.([]LogField)
|
fields, ok := vals.([]LogField)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
@@ -55,7 +55,7 @@ func TestWithFieldsAppend(t *testing.T) {
|
|||||||
ctx := context.WithValue(context.Background(), dummyKey, "dummy")
|
ctx := context.WithValue(context.Background(), dummyKey, "dummy")
|
||||||
ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2))
|
ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2))
|
||||||
ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4))
|
ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4))
|
||||||
vals := ctx.Value(fieldsContextKey)
|
vals := ctx.Value(fieldsKey{})
|
||||||
assert.NotNil(t, vals)
|
assert.NotNil(t, vals)
|
||||||
fields, ok := vals.([]LogField)
|
fields, ok := vals.([]LogField)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
@@ -80,8 +80,8 @@ func TestWithFieldsAppendCopy(t *testing.T) {
|
|||||||
ctxa := ContextWithFields(ctx, af)
|
ctxa := ContextWithFields(ctx, af)
|
||||||
ctxb := ContextWithFields(ctx, bf)
|
ctxb := ContextWithFields(ctx, bf)
|
||||||
|
|
||||||
assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count])
|
assert.EqualValues(t, af, ctxa.Value(fieldsKey{}).([]LogField)[count])
|
||||||
assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count])
|
assert.EqualValues(t, bf, ctxb.Value(fieldsKey{}).([]LogField)[count])
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkAtomicValue(b *testing.B) {
|
func BenchmarkAtomicValue(b *testing.B) {
|
||||||
|
|||||||
@@ -560,7 +560,7 @@ func shallLogStat() bool {
|
|||||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||||
// The caller should check shallLog before calling this function.
|
// The caller should check shallLog before calling this function.
|
||||||
func writeDebug(val any, fields ...LogField) {
|
func writeDebug(val any, fields ...LogField) {
|
||||||
getWriter().Debug(val, addCaller(fields...)...)
|
getWriter().Debug(val, mergeGlobalFields(addCaller(fields...))...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeError writes v into the error log.
|
// writeError writes v into the error log.
|
||||||
@@ -568,7 +568,7 @@ func writeDebug(val any, fields ...LogField) {
|
|||||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||||
// The caller should check shallLog before calling this function.
|
// The caller should check shallLog before calling this function.
|
||||||
func writeError(val any, fields ...LogField) {
|
func writeError(val any, fields ...LogField) {
|
||||||
getWriter().Error(val, addCaller(fields...)...)
|
getWriter().Error(val, mergeGlobalFields(addCaller(fields...))...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeInfo writes v into info log.
|
// writeInfo writes v into info log.
|
||||||
@@ -576,7 +576,7 @@ func writeError(val any, fields ...LogField) {
|
|||||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||||
// The caller should check shallLog before calling this function.
|
// The caller should check shallLog before calling this function.
|
||||||
func writeInfo(val any, fields ...LogField) {
|
func writeInfo(val any, fields ...LogField) {
|
||||||
getWriter().Info(val, addCaller(fields...)...)
|
getWriter().Info(val, mergeGlobalFields(addCaller(fields...))...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeSevere writes v into severe log.
|
// writeSevere writes v into severe log.
|
||||||
@@ -592,7 +592,7 @@ func writeSevere(msg string) {
|
|||||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||||
// The caller should check shallLog before calling this function.
|
// The caller should check shallLog before calling this function.
|
||||||
func writeSlow(val any, fields ...LogField) {
|
func writeSlow(val any, fields ...LogField) {
|
||||||
getWriter().Slow(val, addCaller(fields...)...)
|
getWriter().Slow(val, mergeGlobalFields(addCaller(fields...))...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeStack writes v into stack log.
|
// writeStack writes v into stack log.
|
||||||
@@ -608,5 +608,5 @@ func writeStack(msg string) {
|
|||||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||||
// The caller should check shallLog before calling this function.
|
// The caller should check shallLog before calling this function.
|
||||||
func writeStat(msg string) {
|
func writeStat(msg string) {
|
||||||
getWriter().Stat(msg, addCaller()...)
|
getWriter().Stat(msg, mergeGlobalFields(addCaller())...)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -206,7 +206,9 @@ func (l *richLogger) WithFields(fields ...LogField) Logger {
|
|||||||
|
|
||||||
func (l *richLogger) buildFields(fields ...LogField) []LogField {
|
func (l *richLogger) buildFields(fields ...LogField) []LogField {
|
||||||
fields = append(l.fields, fields...)
|
fields = append(l.fields, fields...)
|
||||||
|
// caller field should always appear together with global fields
|
||||||
fields = append(fields, Field(callerKey, getCaller(callerDepth+l.callerSkip)))
|
fields = append(fields, Field(callerKey, getCaller(callerDepth+l.callerSkip)))
|
||||||
|
fields = mergeGlobalFields(fields)
|
||||||
|
|
||||||
if l.ctx == nil {
|
if l.ctx == nil {
|
||||||
return fields
|
return fields
|
||||||
@@ -222,7 +224,7 @@ func (l *richLogger) buildFields(fields ...LogField) []LogField {
|
|||||||
fields = append(fields, Field(spanKey, spanID))
|
fields = append(fields, Field(spanKey, spanID))
|
||||||
}
|
}
|
||||||
|
|
||||||
val := l.ctx.Value(fieldsContextKey)
|
val := l.ctx.Value(fieldsKey{})
|
||||||
if val != nil {
|
if val != nil {
|
||||||
if arr, ok := val.([]LogField); ok {
|
if arr, ok := val.([]LogField); ok {
|
||||||
fields = append(fields, arr...)
|
fields = append(fields, arr...)
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dateFormat = "2006-01-02"
|
|
||||||
hoursPerDay = 24
|
hoursPerDay = 24
|
||||||
bufferSize = 100
|
bufferSize = 100
|
||||||
defaultDirMode = 0o755
|
defaultDirMode = 0o755
|
||||||
@@ -116,7 +115,7 @@ func (r *DailyRotateRule) OutdatedFiles() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var buf strings.Builder
|
var buf strings.Builder
|
||||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(dateFormat)
|
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(time.DateOnly)
|
||||||
buf.WriteString(r.filename)
|
buf.WriteString(r.filename)
|
||||||
buf.WriteString(r.delimiter)
|
buf.WriteString(r.delimiter)
|
||||||
buf.WriteString(boundary)
|
buf.WriteString(boundary)
|
||||||
@@ -425,7 +424,7 @@ func compressLogFile(file string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getNowDate() string {
|
func getNowDate() string {
|
||||||
return time.Now().Format(dateFormat)
|
return time.Now().Format(time.DateOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNowDateInRFC3339Format() string {
|
func getNowDateInRFC3339Format() string {
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("temp files", func(t *testing.T) {
|
t.Run("temp files", func(t *testing.T) {
|
||||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
_ = f1.Close()
|
_ = f1.Close()
|
||||||
@@ -73,7 +73,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
|||||||
|
|
||||||
func TestDailyRotateRuleShallRotate(t *testing.T) {
|
func TestDailyRotateRuleShallRotate(t *testing.T) {
|
||||||
var rule DailyRotateRule
|
var rule DailyRotateRule
|
||||||
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(dateFormat)
|
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(time.DateOnly)
|
||||||
assert.True(t, rule.ShallRotate(0))
|
assert.True(t, rule.ShallRotate(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -117,12 +117,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("temp files", func(t *testing.T) {
|
t.Run("temp files", func(t *testing.T) {
|
||||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -144,12 +144,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("no backups", func(t *testing.T) {
|
t.Run("no backups", func(t *testing.T) {
|
||||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -319,7 +319,7 @@ func TestRotateLoggerWrite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||||
logger.write([]byte(`foo`))
|
logger.write([]byte(`foo`))
|
||||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||||
logger.write([]byte(`bar`))
|
logger.write([]byte(`bar`))
|
||||||
logger.Close()
|
logger.Close()
|
||||||
logger.write([]byte(`baz`))
|
logger.write([]byte(`baz`))
|
||||||
@@ -447,7 +447,7 @@ func TestRotateLoggerWithSizeLimitRotateRuleWrite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||||
logger.write([]byte(`foo`))
|
logger.write([]byte(`foo`))
|
||||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||||
logger.write([]byte(`bar`))
|
logger.write([]byte(`bar`))
|
||||||
logger.Close()
|
logger.Close()
|
||||||
logger.write([]byte(`baz`))
|
logger.write([]byte(`baz`))
|
||||||
|
|||||||
@@ -17,15 +17,27 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
// Writer is the interface for writing logs.
|
||||||
|
// It's designed to let users customize their own log writer,
|
||||||
|
// such as writing logs to a kafka, a database, or using third-party loggers.
|
||||||
Writer interface {
|
Writer interface {
|
||||||
|
// Alert sends an alert message, if your writer implemented alerting functionality.
|
||||||
Alert(v any)
|
Alert(v any)
|
||||||
|
// Close closes the writer.
|
||||||
Close() error
|
Close() error
|
||||||
|
// Debug logs a message at debug level.
|
||||||
Debug(v any, fields ...LogField)
|
Debug(v any, fields ...LogField)
|
||||||
|
// Error logs a message at error level.
|
||||||
Error(v any, fields ...LogField)
|
Error(v any, fields ...LogField)
|
||||||
|
// Info logs a message at info level.
|
||||||
Info(v any, fields ...LogField)
|
Info(v any, fields ...LogField)
|
||||||
|
// Severe logs a message at severe level.
|
||||||
Severe(v any)
|
Severe(v any)
|
||||||
|
// Slow logs a message at slow level.
|
||||||
Slow(v any, fields ...LogField)
|
Slow(v any, fields ...LogField)
|
||||||
|
// Stack logs a message at error level.
|
||||||
Stack(v any)
|
Stack(v any)
|
||||||
|
// Stat logs a message at stat level.
|
||||||
Stat(v any, fields ...LogField)
|
Stat(v any, fields ...LogField)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -324,20 +336,6 @@ func buildPlainFields(fields logEntry) []string {
|
|||||||
return items
|
return items
|
||||||
}
|
}
|
||||||
|
|
||||||
func combineGlobalFields(fields []LogField) []LogField {
|
|
||||||
globals := globalFields.Load()
|
|
||||||
if globals == nil {
|
|
||||||
return fields
|
|
||||||
}
|
|
||||||
|
|
||||||
gf := globals.([]LogField)
|
|
||||||
ret := make([]LogField, 0, len(gf)+len(fields))
|
|
||||||
ret = append(ret, gf...)
|
|
||||||
ret = append(ret, fields...)
|
|
||||||
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalJson(t interface{}) ([]byte, error) {
|
func marshalJson(t interface{}) ([]byte, error) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
encoder := json.NewEncoder(&buf)
|
encoder := json.NewEncoder(&buf)
|
||||||
@@ -352,6 +350,20 @@ func marshalJson(t interface{}) ([]byte, error) {
|
|||||||
return buf.Bytes(), err
|
return buf.Bytes(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mergeGlobalFields(fields []LogField) []LogField {
|
||||||
|
globals := globalFields.Load()
|
||||||
|
if globals == nil {
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
gf := globals.([]LogField)
|
||||||
|
ret := make([]LogField, 0, len(gf)+len(fields))
|
||||||
|
ret = append(ret, gf...)
|
||||||
|
ret = append(ret, fields...)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
func output(writer io.Writer, level string, val any, fields ...LogField) {
|
func output(writer io.Writer, level string, val any, fields ...LogField) {
|
||||||
// only truncate string content, don't know how to truncate the values of other types.
|
// only truncate string content, don't know how to truncate the values of other types.
|
||||||
if v, ok := val.(string); ok {
|
if v, ok := val.(string); ok {
|
||||||
@@ -362,7 +374,6 @@ func output(writer io.Writer, level string, val any, fields ...LogField) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fields = combineGlobalFields(fields)
|
|
||||||
// +3 for timestamp, level and content
|
// +3 for timestamp, level and content
|
||||||
entry := make(logEntry, len(fields)+3)
|
entry := make(logEntry, len(fields)+3)
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
|
|||||||
@@ -13,6 +13,15 @@ const (
|
|||||||
|
|
||||||
// Marshal marshals the given val and returns the map that contains the fields.
|
// Marshal marshals the given val and returns the map that contains the fields.
|
||||||
// optional=another is not implemented, and it's hard to implement and not commonly used.
|
// optional=another is not implemented, and it's hard to implement and not commonly used.
|
||||||
|
// support anonymous field, e.g.:
|
||||||
|
//
|
||||||
|
// type Foo struct {
|
||||||
|
// Token string `header:"token"`
|
||||||
|
// }
|
||||||
|
// type FooB struct {
|
||||||
|
// Foo
|
||||||
|
// Bar string `json:"bar"`
|
||||||
|
// }
|
||||||
func Marshal(val any) (map[string]map[string]any, error) {
|
func Marshal(val any) (map[string]map[string]any, error) {
|
||||||
ret := make(map[string]map[string]any)
|
ret := make(map[string]map[string]any)
|
||||||
tp := reflect.TypeOf(val)
|
tp := reflect.TypeOf(val)
|
||||||
@@ -44,6 +53,16 @@ func getTag(field reflect.StructField) (string, bool) {
|
|||||||
return strings.TrimSpace(tag), false
|
return strings.TrimSpace(tag), false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func insertValue(collector map[string]map[string]any, tag string, key string, val any) {
|
||||||
|
if m, ok := collector[tag]; ok {
|
||||||
|
m[key] = val
|
||||||
|
} else {
|
||||||
|
collector[tag] = map[string]any{
|
||||||
|
key: val,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func processMember(field reflect.StructField, value reflect.Value,
|
func processMember(field reflect.StructField, value reflect.Value,
|
||||||
collector map[string]map[string]any) error {
|
collector map[string]map[string]any) error {
|
||||||
var key string
|
var key string
|
||||||
@@ -69,15 +88,20 @@ func processMember(field reflect.StructField, value reflect.Value,
|
|||||||
val = fmt.Sprint(val)
|
val = fmt.Sprint(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
m, ok := collector[tag]
|
if field.Anonymous {
|
||||||
if ok {
|
anonCollector, err := Marshal(val)
|
||||||
m[key] = val
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for anonTag, anonMap := range anonCollector {
|
||||||
|
for anonKey, anonVal := range anonMap {
|
||||||
|
insertValue(collector, anonTag, anonKey, anonVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
m = map[string]any{
|
insertValue(collector, tag, key, val)
|
||||||
key: val,
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
collector[tag] = m
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -118,7 +142,7 @@ func validateOptional(field reflect.StructField, value reflect.Value) error {
|
|||||||
if value.IsNil() {
|
if value.IsNil() {
|
||||||
return fmt.Errorf("field %q is nil", field.Name)
|
return fmt.Errorf("field %q is nil", field.Name)
|
||||||
}
|
}
|
||||||
case reflect.Array, reflect.Slice, reflect.Map:
|
case reflect.Slice, reflect.Map:
|
||||||
if value.IsNil() || value.Len() == 0 {
|
if value.IsNil() || value.Len() == 0 {
|
||||||
return fmt.Errorf("field %q is empty", field.Name)
|
return fmt.Errorf("field %q is empty", field.Name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,6 +27,124 @@ func TestMarshal(t *testing.T) {
|
|||||||
assert.True(t, m[emptyTag]["Anonymous"].(bool))
|
assert.True(t, m[emptyTag]["Anonymous"].(bool))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMarshal_Anonymous(t *testing.T) {
|
||||||
|
t.Run("anonymous", func(t *testing.T) {
|
||||||
|
type BaseHeader struct {
|
||||||
|
Token string `header:"token"`
|
||||||
|
}
|
||||||
|
v := struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Address string `json:"address,options=[beijing,shanghai]"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
BaseHeader
|
||||||
|
}{
|
||||||
|
Name: "kevin",
|
||||||
|
Address: "shanghai",
|
||||||
|
Age: 20,
|
||||||
|
BaseHeader: BaseHeader{
|
||||||
|
Token: "token_xxx",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m, err := Marshal(v)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "kevin", m["json"]["name"])
|
||||||
|
assert.Equal(t, "shanghai", m["json"]["address"])
|
||||||
|
assert.Equal(t, 20, m["json"]["age"].(int))
|
||||||
|
assert.Equal(t, "token_xxx", m["header"]["token"])
|
||||||
|
|
||||||
|
v1 := struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Address string `json:"address,options=[beijing,shanghai]"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
BaseHeader
|
||||||
|
}{
|
||||||
|
Name: "kevin",
|
||||||
|
Address: "shanghai",
|
||||||
|
Age: 20,
|
||||||
|
}
|
||||||
|
m1, err1 := Marshal(v1)
|
||||||
|
assert.Nil(t, err1)
|
||||||
|
assert.Equal(t, "kevin", m1["json"]["name"])
|
||||||
|
assert.Equal(t, "shanghai", m1["json"]["address"])
|
||||||
|
assert.Equal(t, 20, m1["json"]["age"].(int))
|
||||||
|
|
||||||
|
type AnotherHeader struct {
|
||||||
|
Version string `header:"version"`
|
||||||
|
}
|
||||||
|
v2 := struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Address string `json:"address,options=[beijing,shanghai]"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
BaseHeader
|
||||||
|
AnotherHeader
|
||||||
|
}{
|
||||||
|
Name: "kevin",
|
||||||
|
Address: "shanghai",
|
||||||
|
Age: 20,
|
||||||
|
BaseHeader: BaseHeader{
|
||||||
|
Token: "token_xxx",
|
||||||
|
},
|
||||||
|
AnotherHeader: AnotherHeader{
|
||||||
|
Version: "v1.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m2, err2 := Marshal(v2)
|
||||||
|
assert.Nil(t, err2)
|
||||||
|
assert.Equal(t, "kevin", m2["json"]["name"])
|
||||||
|
assert.Equal(t, "shanghai", m2["json"]["address"])
|
||||||
|
assert.Equal(t, 20, m2["json"]["age"].(int))
|
||||||
|
assert.Equal(t, "token_xxx", m2["header"]["token"])
|
||||||
|
assert.Equal(t, "v1.0", m2["header"]["version"])
|
||||||
|
|
||||||
|
type PointerHeader struct {
|
||||||
|
Ref *string `header:"ref"`
|
||||||
|
}
|
||||||
|
ref := "reference"
|
||||||
|
v3 := struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Address string `json:"address,options=[beijing,shanghai]"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
PointerHeader
|
||||||
|
}{
|
||||||
|
Name: "kevin",
|
||||||
|
Address: "shanghai",
|
||||||
|
Age: 20,
|
||||||
|
PointerHeader: PointerHeader{
|
||||||
|
Ref: &ref,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m3, err3 := Marshal(v3)
|
||||||
|
assert.Nil(t, err3)
|
||||||
|
assert.Equal(t, "kevin", m3["json"]["name"])
|
||||||
|
assert.Equal(t, "shanghai", m3["json"]["address"])
|
||||||
|
assert.Equal(t, 20, m3["json"]["age"].(int))
|
||||||
|
assert.Equal(t, "reference", *m3["header"]["ref"].(*string))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bad anonymous", func(t *testing.T) {
|
||||||
|
type BaseHeader struct {
|
||||||
|
Token string `json:"token,options=[a,b]"`
|
||||||
|
}
|
||||||
|
|
||||||
|
v := struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Address string `json:"address,options=[beijing,shanghai]"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
BaseHeader
|
||||||
|
}{
|
||||||
|
Name: "kevin",
|
||||||
|
Address: "shanghai",
|
||||||
|
Age: 20,
|
||||||
|
BaseHeader: BaseHeader{
|
||||||
|
Token: "c",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Marshal(v)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestMarshal_Ptr(t *testing.T) {
|
func TestMarshal_Ptr(t *testing.T) {
|
||||||
v := &struct {
|
v := &struct {
|
||||||
Name string `path:"name"`
|
Name string `path:"name"`
|
||||||
@@ -344,3 +462,15 @@ func TestMarshal_FromString(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10", m["json"]["age"].(string))
|
assert.Equal(t, "10", m["json"]["age"].(string))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMarshal_Array(t *testing.T) {
|
||||||
|
v := struct {
|
||||||
|
H [1]int `json:"h,string"`
|
||||||
|
}{
|
||||||
|
H: [1]int{1},
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Marshal(v)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "[1]", m["json"]["h"].(string))
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -15,11 +16,9 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/jsonx"
|
"github.com/zeromicro/go-zero/core/jsonx"
|
||||||
"github.com/zeromicro/go-zero/core/lang"
|
"github.com/zeromicro/go-zero/core/lang"
|
||||||
"github.com/zeromicro/go-zero/core/proc"
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
comma = ","
|
|
||||||
defaultKeyName = "key"
|
defaultKeyName = "key"
|
||||||
delimiter = '.'
|
delimiter = '.'
|
||||||
ignoreKey = "-"
|
ignoreKey = "-"
|
||||||
@@ -31,14 +30,15 @@ var (
|
|||||||
errValueNotSettable = errors.New("value is not settable")
|
errValueNotSettable = errors.New("value is not settable")
|
||||||
errValueNotStruct = errors.New("value type is not struct")
|
errValueNotStruct = errors.New("value type is not struct")
|
||||||
keyUnmarshaler = NewUnmarshaler(defaultKeyName)
|
keyUnmarshaler = NewUnmarshaler(defaultKeyName)
|
||||||
|
boolType = reflect.TypeOf(false)
|
||||||
durationType = reflect.TypeOf(time.Duration(0))
|
durationType = reflect.TypeOf(time.Duration(0))
|
||||||
|
stringType = reflect.TypeOf("")
|
||||||
cacheKeys = make(map[string][]string)
|
cacheKeys = make(map[string][]string)
|
||||||
cacheKeysLock sync.Mutex
|
cacheKeysLock sync.Mutex
|
||||||
defaultCache = make(map[string]any)
|
defaultCache = make(map[string]any)
|
||||||
defaultCacheLock sync.Mutex
|
defaultCacheLock sync.Mutex
|
||||||
emptyMap = map[string]any{}
|
emptyMap = map[string]any{}
|
||||||
emptyValue = reflect.ValueOf(lang.Placeholder)
|
emptyValue = reflect.ValueOf(lang.Placeholder)
|
||||||
stringSliceType = reflect.TypeOf([]string{})
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@@ -152,10 +152,6 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.opts.fromArray {
|
|
||||||
refValue = makeStringSlice(refValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
var valid bool
|
var valid bool
|
||||||
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
|
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
|
||||||
|
|
||||||
@@ -628,9 +624,19 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
|
|||||||
|
|
||||||
return u.fillSliceFromString(fieldType, value, mapValue, fullName)
|
return u.fillSliceFromString(fieldType, value, mapValue, fullName)
|
||||||
case valueKind == reflect.String && derefedFieldType == durationType:
|
case valueKind == reflect.String && derefedFieldType == durationType:
|
||||||
return fillDurationValue(fieldType, value, mapValue.(string))
|
v, err := convertToString(mapValue, fullName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fillDurationValue(fieldType, value, v)
|
||||||
case valueKind == reflect.String && typeKind == reflect.Struct && u.implementsUnmarshaler(fieldType):
|
case valueKind == reflect.String && typeKind == reflect.Struct && u.implementsUnmarshaler(fieldType):
|
||||||
return u.fillUnmarshalerStruct(fieldType, value, mapValue.(string))
|
v, err := convertToString(mapValue, fullName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return u.fillUnmarshalerStruct(fieldType, value, v)
|
||||||
default:
|
default:
|
||||||
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
|
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
|
||||||
}
|
}
|
||||||
@@ -761,24 +767,24 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fieldKind := fieldType.Kind()
|
derefType := Deref(fieldType)
|
||||||
switch fieldKind {
|
switch derefType {
|
||||||
case reflect.Bool:
|
case boolType:
|
||||||
val, err := strconv.ParseBool(envVal)
|
val, err := strconv.ParseBool(envVal)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
value.SetBool(val)
|
SetValue(fieldType, value, reflect.ValueOf(val))
|
||||||
return nil
|
return nil
|
||||||
case durationType.Kind():
|
case durationType:
|
||||||
if err := fillDurationValue(fieldType, value, envVal); err != nil {
|
if err := fillDurationValue(fieldType, value, envVal); err != nil {
|
||||||
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
case reflect.String:
|
case stringType:
|
||||||
value.SetString(envVal)
|
SetValue(fieldType, value, reflect.ValueOf(envVal))
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, json.Number(envVal), opts, fullName)
|
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, json.Number(envVal), opts, fullName)
|
||||||
@@ -900,7 +906,7 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ
|
|||||||
valueKind.String())
|
valueKind.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if !stringx.Contains(options, checkValue) {
|
if !slices.Contains(options, checkValue) {
|
||||||
return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`,
|
return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`,
|
||||||
mapValue, key, options)
|
mapValue, key, options)
|
||||||
}
|
}
|
||||||
@@ -1189,35 +1195,6 @@ func join(elem ...string) string {
|
|||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeStringSlice(refValue reflect.Value) reflect.Value {
|
|
||||||
if refValue.Len() != 1 {
|
|
||||||
return refValue
|
|
||||||
}
|
|
||||||
|
|
||||||
element := refValue.Index(0)
|
|
||||||
if element.Kind() != reflect.String {
|
|
||||||
return refValue
|
|
||||||
}
|
|
||||||
|
|
||||||
val, ok := element.Interface().(string)
|
|
||||||
if !ok {
|
|
||||||
return refValue
|
|
||||||
}
|
|
||||||
|
|
||||||
splits := strings.Split(val, comma)
|
|
||||||
if len(splits) <= 1 {
|
|
||||||
return refValue
|
|
||||||
}
|
|
||||||
|
|
||||||
slice := reflect.MakeSlice(stringSliceType, len(splits), len(splits))
|
|
||||||
for i, split := range splits {
|
|
||||||
// allow empty strings
|
|
||||||
slice.Index(i).Set(reflect.ValueOf(split))
|
|
||||||
}
|
|
||||||
|
|
||||||
return slice
|
|
||||||
}
|
|
||||||
|
|
||||||
func newInitError(name string) error {
|
func newInitError(name string) error {
|
||||||
return fmt.Errorf("field %q is not set", name)
|
return fmt.Errorf("field %q is not set", name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -203,6 +203,20 @@ func TestUnmarshalDuration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalDurationUnexpectedError(t *testing.T) {
|
||||||
|
type inner struct {
|
||||||
|
Duration time.Duration `key:"duration"`
|
||||||
|
}
|
||||||
|
content := "{\"duration\": 1}"
|
||||||
|
var m = map[string]any{}
|
||||||
|
err := jsonx.Unmarshal([]byte(content), &m)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
var in inner
|
||||||
|
err = UnmarshalKey(m, &in)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "expect string")
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnmarshalDurationDefault(t *testing.T) {
|
func TestUnmarshalDurationDefault(t *testing.T) {
|
||||||
type inner struct {
|
type inner struct {
|
||||||
Int int `key:"int"`
|
Int int `key:"int"`
|
||||||
@@ -1462,9 +1476,7 @@ func TestUnmarshalIntSlice(t *testing.T) {
|
|||||||
|
|
||||||
ast := assert.New(t)
|
ast := assert.New(t)
|
||||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
ast.Error(unmarshaler.Unmarshal(m, &v))
|
||||||
ast.ElementsMatch([]int{1, 2}, v.Ages)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1546,7 +1558,22 @@ func TestUnmarshalStringSliceFromString(t *testing.T) {
|
|||||||
ast := assert.New(t)
|
ast := assert.New(t)
|
||||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||||
ast.ElementsMatch([]string{"", ""}, v.Names)
|
ast.ElementsMatch([]string{","}, v.Names)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("slice from valid strings with comma", func(t *testing.T) {
|
||||||
|
var v struct {
|
||||||
|
Names []string `key:"names"`
|
||||||
|
}
|
||||||
|
m := map[string]any{
|
||||||
|
"names": []string{"aa,bb"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ast := assert.New(t)
|
||||||
|
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||||
|
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||||
|
ast.ElementsMatch([]string{"aa,bb"}, v.Names)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -4652,6 +4679,23 @@ func TestUnmarshal_EnvInt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnmarshal_EnvInt64(t *testing.T) {
|
||||||
|
type Value struct {
|
||||||
|
Age int64 `key:"age,env=TEST_NAME_INT64"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
envName = "TEST_NAME_INT64"
|
||||||
|
envVal = "88"
|
||||||
|
)
|
||||||
|
t.Setenv(envName, envVal)
|
||||||
|
|
||||||
|
var v Value
|
||||||
|
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||||
|
assert.Equal(t, int64(88), v.Age)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
|
func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
|
||||||
type Value struct {
|
type Value struct {
|
||||||
Age int `key:"age,env=TEST_NAME_INT"`
|
Age int `key:"age,env=TEST_NAME_INT"`
|
||||||
@@ -4757,20 +4801,33 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshal_EnvDuration(t *testing.T) {
|
func TestUnmarshal_EnvDuration(t *testing.T) {
|
||||||
type Value struct {
|
|
||||||
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
envName = "TEST_NAME_DURATION"
|
envName = "TEST_NAME_DURATION"
|
||||||
envVal = "1s"
|
envVal = "1s"
|
||||||
)
|
)
|
||||||
t.Setenv(envName, envVal)
|
t.Setenv(envName, envVal)
|
||||||
|
|
||||||
|
t.Run("valid duration", func(t *testing.T) {
|
||||||
|
type Value struct {
|
||||||
|
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||||
|
}
|
||||||
|
|
||||||
var v Value
|
var v Value
|
||||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||||
assert.Equal(t, time.Second, v.Duration)
|
assert.Equal(t, time.Second, v.Duration)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ptr of duration", func(t *testing.T) {
|
||||||
|
type Value struct {
|
||||||
|
Duration *time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var v Value
|
||||||
|
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||||
|
assert.Equal(t, time.Second, *v.Duration)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
|
func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
|
||||||
@@ -5982,6 +6039,16 @@ func TestUnmarshal_Unmarshaler(t *testing.T) {
|
|||||||
}, &v))
|
}, &v))
|
||||||
assert.Nil(t, v.Foo)
|
assert.Nil(t, v.Foo)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("json.Number", func(t *testing.T) {
|
||||||
|
v := struct {
|
||||||
|
Foo *mockUnmarshaler `json:"name"`
|
||||||
|
}{}
|
||||||
|
m := map[string]any{
|
||||||
|
"name": json.Number("123"),
|
||||||
|
}
|
||||||
|
assert.Error(t, UnmarshalJsonMap(m, &v))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseJsonStringValue(t *testing.T) {
|
func TestParseJsonStringValue(t *testing.T) {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -91,6 +92,15 @@ func ValidatePtr(v reflect.Value) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertToString(val any, fullName string) (string, error) {
|
||||||
|
v, ok := val.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("expect string for field %s, but got type %T", fullName, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
|
func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
|
||||||
switch kind {
|
switch kind {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
@@ -634,11 +644,11 @@ func validateValueInOptions(val any, options []string) error {
|
|||||||
if len(options) > 0 {
|
if len(options) > 0 {
|
||||||
switch v := val.(type) {
|
switch v := val.(type) {
|
||||||
case string:
|
case string:
|
||||||
if !stringx.Contains(options, v) {
|
if !slices.Contains(options, v) {
|
||||||
return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options)
|
return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if !stringx.Contains(options, Repr(v)) {
|
if !slices.Contains(options, Repr(v)) {
|
||||||
return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
|
return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,13 @@
|
|||||||
package mathx
|
package mathx
|
||||||
|
|
||||||
// MaxInt returns the larger one of a and b.
|
// MaxInt returns the larger one of a and b.
|
||||||
|
// Deprecated: use builtin max instead.
|
||||||
func MaxInt(a, b int) int {
|
func MaxInt(a, b int) int {
|
||||||
if a > b {
|
return max(a, b)
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
return b
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MinInt returns the smaller one of a and b.
|
// MinInt returns the smaller one of a and b.
|
||||||
|
// Deprecated: use builtin min instead.
|
||||||
func MinInt(a, b int) int {
|
func MinInt(a, b int) int {
|
||||||
if a < b {
|
return min(a, b)
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
return b
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -142,89 +142,6 @@ func MapReduceChan[T, U, V any](source <-chan T, mapper MapperFunc[T, U], reduce
|
|||||||
return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
|
return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapReduceWithPanicChan maps all elements from source, and reduce the output elements with given reducer.
|
|
||||||
func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, mapper MapperFunc[T, U],
|
|
||||||
reducer ReducerFunc[U, V], opts ...Option) (val V, err error) {
|
|
||||||
options := buildOptions(opts...)
|
|
||||||
// output is used to write the final result
|
|
||||||
output := make(chan V)
|
|
||||||
defer func() {
|
|
||||||
// reducer can only write once, if more, panic
|
|
||||||
for range output {
|
|
||||||
panic("more than one element written in reducer")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// collector is used to collect data from mapper, and consume in reducer
|
|
||||||
collector := make(chan U, options.workers)
|
|
||||||
// if done is closed, all mappers and reducer should stop processing
|
|
||||||
done := make(chan struct{})
|
|
||||||
writer := newGuardedWriter(options.ctx, output, done)
|
|
||||||
var closeOnce sync.Once
|
|
||||||
// use atomic type to avoid data race
|
|
||||||
var retErr errorx.AtomicError
|
|
||||||
finish := func() {
|
|
||||||
closeOnce.Do(func() {
|
|
||||||
close(done)
|
|
||||||
close(output)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
cancel := once(func(err error) {
|
|
||||||
if err != nil {
|
|
||||||
retErr.Set(err)
|
|
||||||
} else {
|
|
||||||
retErr.Set(ErrCancelWithNil)
|
|
||||||
}
|
|
||||||
|
|
||||||
drain(source)
|
|
||||||
finish()
|
|
||||||
})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
drain(collector)
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
panicChan.write(r)
|
|
||||||
}
|
|
||||||
finish()
|
|
||||||
}()
|
|
||||||
|
|
||||||
reducer(collector, writer, cancel)
|
|
||||||
}()
|
|
||||||
|
|
||||||
go executeMappers(mapperContext[T, U]{
|
|
||||||
ctx: options.ctx,
|
|
||||||
mapper: func(item T, w Writer[U]) {
|
|
||||||
mapper(item, w, cancel)
|
|
||||||
},
|
|
||||||
source: source,
|
|
||||||
panicChan: panicChan,
|
|
||||||
collector: collector,
|
|
||||||
doneChan: done,
|
|
||||||
workers: options.workers,
|
|
||||||
})
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-options.ctx.Done():
|
|
||||||
cancel(context.DeadlineExceeded)
|
|
||||||
err = context.DeadlineExceeded
|
|
||||||
case v := <-panicChan.channel:
|
|
||||||
// drain output here, otherwise for loop panic in defer
|
|
||||||
drain(output)
|
|
||||||
panic(v)
|
|
||||||
case v, ok := <-output:
|
|
||||||
if e := retErr.Load(); e != nil {
|
|
||||||
err = e
|
|
||||||
} else if ok {
|
|
||||||
val = v
|
|
||||||
} else {
|
|
||||||
err = ErrReduceNoOutput
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// MapReduceVoid maps all elements generated from given generate,
|
// MapReduceVoid maps all elements generated from given generate,
|
||||||
// and reduce the output elements with given reducer.
|
// and reduce the output elements with given reducer.
|
||||||
func MapReduceVoid[T, U any](generate GenerateFunc[T], mapper MapperFunc[T, U],
|
func MapReduceVoid[T, U any](generate GenerateFunc[T], mapper MapperFunc[T, U],
|
||||||
@@ -330,6 +247,89 @@ func executeMappers[T, U any](mCtx mapperContext[T, U]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mapReduceWithPanicChan maps all elements from source, and reduce the output elements with given reducer.
|
||||||
|
func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, mapper MapperFunc[T, U],
|
||||||
|
reducer ReducerFunc[U, V], opts ...Option) (val V, err error) {
|
||||||
|
options := buildOptions(opts...)
|
||||||
|
// output is used to write the final result
|
||||||
|
output := make(chan V)
|
||||||
|
defer func() {
|
||||||
|
// reducer can only write once, if more, panic
|
||||||
|
for range output {
|
||||||
|
panic("more than one element written in reducer")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// collector is used to collect data from mapper, and consume in reducer
|
||||||
|
collector := make(chan U, options.workers)
|
||||||
|
// if done is closed, all mappers and reducer should stop processing
|
||||||
|
done := make(chan struct{})
|
||||||
|
writer := newGuardedWriter(options.ctx, output, done)
|
||||||
|
var closeOnce sync.Once
|
||||||
|
// use atomic type to avoid data race
|
||||||
|
var retErr errorx.AtomicError
|
||||||
|
finish := func() {
|
||||||
|
closeOnce.Do(func() {
|
||||||
|
close(done)
|
||||||
|
close(output)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
cancel := once(func(err error) {
|
||||||
|
if err != nil {
|
||||||
|
retErr.Set(err)
|
||||||
|
} else {
|
||||||
|
retErr.Set(ErrCancelWithNil)
|
||||||
|
}
|
||||||
|
|
||||||
|
drain(source)
|
||||||
|
finish()
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
drain(collector)
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
panicChan.write(r)
|
||||||
|
}
|
||||||
|
finish()
|
||||||
|
}()
|
||||||
|
|
||||||
|
reducer(collector, writer, cancel)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go executeMappers(mapperContext[T, U]{
|
||||||
|
ctx: options.ctx,
|
||||||
|
mapper: func(item T, w Writer[U]) {
|
||||||
|
mapper(item, w, cancel)
|
||||||
|
},
|
||||||
|
source: source,
|
||||||
|
panicChan: panicChan,
|
||||||
|
collector: collector,
|
||||||
|
doneChan: done,
|
||||||
|
workers: options.workers,
|
||||||
|
})
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-options.ctx.Done():
|
||||||
|
cancel(context.DeadlineExceeded)
|
||||||
|
err = context.DeadlineExceeded
|
||||||
|
case v := <-panicChan.channel:
|
||||||
|
// drain output here, otherwise for loop panic in defer
|
||||||
|
drain(output)
|
||||||
|
panic(v)
|
||||||
|
case v, ok := <-output:
|
||||||
|
if e := retErr.Load(); e != nil {
|
||||||
|
err = e
|
||||||
|
} else if ok {
|
||||||
|
val = v
|
||||||
|
} else {
|
||||||
|
err = ErrReduceNoOutput
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func newOptions() *mapReduceOptions {
|
func newOptions() *mapReduceOptions {
|
||||||
return &mapReduceOptions{
|
return &mapReduceOptions{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
package prof
|
package prof
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"fmt"
|
||||||
"strconv"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/olekukonko/tablewriter"
|
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/threading"
|
"github.com/zeromicro/go-zero/core/threading"
|
||||||
)
|
)
|
||||||
@@ -28,46 +27,15 @@ type (
|
|||||||
|
|
||||||
const flushInterval = 5 * time.Minute
|
const flushInterval = 5 * time.Minute
|
||||||
|
|
||||||
var (
|
var pc = &profileCenter{
|
||||||
pc = &profileCenter{
|
|
||||||
slots: make(map[string]*profileSlot),
|
slots: make(map[string]*profileSlot),
|
||||||
}
|
}
|
||||||
once sync.Once
|
|
||||||
)
|
|
||||||
|
|
||||||
func report(name string, duration time.Duration) {
|
func init() {
|
||||||
updated := func() bool {
|
flushRepeatedly()
|
||||||
pc.lock.RLock()
|
|
||||||
defer pc.lock.RUnlock()
|
|
||||||
|
|
||||||
slot, ok := pc.slots[name]
|
|
||||||
if ok {
|
|
||||||
atomic.AddInt64(&slot.lifecount, 1)
|
|
||||||
atomic.AddInt64(&slot.lastcount, 1)
|
|
||||||
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
|
||||||
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
|
||||||
}
|
|
||||||
return ok
|
|
||||||
}()
|
|
||||||
|
|
||||||
if !updated {
|
|
||||||
func() {
|
|
||||||
pc.lock.Lock()
|
|
||||||
defer pc.lock.Unlock()
|
|
||||||
|
|
||||||
pc.slots[name] = &profileSlot{
|
|
||||||
lifecount: 1,
|
|
||||||
lastcount: 1,
|
|
||||||
lifecycle: int64(duration),
|
|
||||||
lastcycle: int64(duration),
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
once.Do(flushRepeatly)
|
func flushRepeatedly() {
|
||||||
}
|
|
||||||
|
|
||||||
func flushRepeatly() {
|
|
||||||
threading.GoSafe(func() {
|
threading.GoSafe(func() {
|
||||||
for {
|
for {
|
||||||
time.Sleep(flushInterval)
|
time.Sleep(flushInterval)
|
||||||
@@ -76,42 +44,64 @@ func flushRepeatly() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func report(name string, duration time.Duration) {
|
||||||
|
slot := loadOrStoreSlot(name, duration)
|
||||||
|
|
||||||
|
atomic.AddInt64(&slot.lifecount, 1)
|
||||||
|
atomic.AddInt64(&slot.lastcount, 1)
|
||||||
|
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
||||||
|
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadOrStoreSlot(name string, duration time.Duration) *profileSlot {
|
||||||
|
pc.lock.RLock()
|
||||||
|
slot, ok := pc.slots[name]
|
||||||
|
pc.lock.RUnlock()
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
return slot
|
||||||
|
}
|
||||||
|
|
||||||
|
pc.lock.Lock()
|
||||||
|
defer pc.lock.Unlock()
|
||||||
|
|
||||||
|
// double-check
|
||||||
|
if slot, ok = pc.slots[name]; ok {
|
||||||
|
return slot
|
||||||
|
}
|
||||||
|
|
||||||
|
slot = &profileSlot{}
|
||||||
|
pc.slots[name] = slot
|
||||||
|
return slot
|
||||||
|
}
|
||||||
|
|
||||||
func generateReport() string {
|
func generateReport() string {
|
||||||
var buffer bytes.Buffer
|
var builder strings.Builder
|
||||||
buffer.WriteString("Profiling report\n")
|
builder.WriteString("Profiling report\n")
|
||||||
var data [][]string
|
builder.WriteString("QUEUE,LIFECOUNT,LIFECYCLE,LASTCOUNT,LASTCYCLE\n")
|
||||||
|
|
||||||
calcFn := func(total, count int64) string {
|
calcFn := func(total, count int64) string {
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
return "-"
|
return "-"
|
||||||
}
|
}
|
||||||
|
|
||||||
return (time.Duration(total) / time.Duration(count)).String()
|
return (time.Duration(total) / time.Duration(count)).String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func() {
|
|
||||||
pc.lock.Lock()
|
pc.lock.Lock()
|
||||||
defer pc.lock.Unlock()
|
|
||||||
|
|
||||||
for key, slot := range pc.slots {
|
for key, slot := range pc.slots {
|
||||||
data = append(data, []string{
|
builder.WriteString(fmt.Sprintf("%s,%d,%s,%d,%s\n",
|
||||||
key,
|
key,
|
||||||
strconv.FormatInt(slot.lifecount, 10),
|
slot.lifecount,
|
||||||
calcFn(slot.lifecycle, slot.lifecount),
|
calcFn(slot.lifecycle, slot.lifecount),
|
||||||
strconv.FormatInt(slot.lastcount, 10),
|
slot.lastcount,
|
||||||
calcFn(slot.lastcycle, slot.lastcount),
|
calcFn(slot.lastcycle, slot.lastcount),
|
||||||
})
|
))
|
||||||
|
|
||||||
// reset the data for last cycle
|
// reset last cycle stats
|
||||||
slot.lastcount = 0
|
atomic.StoreInt64(&slot.lastcount, 0)
|
||||||
slot.lastcycle = 0
|
atomic.StoreInt64(&slot.lastcycle, 0)
|
||||||
}
|
}
|
||||||
}()
|
pc.lock.Unlock()
|
||||||
|
|
||||||
table := tablewriter.NewWriter(&buffer)
|
return builder.String()
|
||||||
table.SetHeader([]string{"QUEUE", "LIFECOUNT", "LIFECYCLE", "LASTCOUNT", "LASTCYCLE"})
|
|
||||||
table.SetBorder(false)
|
|
||||||
table.AppendBulk(data)
|
|
||||||
table.Render()
|
|
||||||
|
|
||||||
return buffer.String()
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestReport(t *testing.T) {
|
func TestReport(t *testing.T) {
|
||||||
once.Do(func() {})
|
|
||||||
assert.NotContains(t, generateReport(), "foo")
|
assert.NotContains(t, generateReport(), "foo")
|
||||||
report("foo", time.Second)
|
report("foo", time.Second)
|
||||||
assert.Contains(t, generateReport(), "foo")
|
assert.Contains(t, generateReport(), "foo")
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
"github.com/zeromicro/go-zero/core/trace"
|
"github.com/zeromicro/go-zero/core/trace"
|
||||||
"github.com/zeromicro/go-zero/internal/devserver"
|
"github.com/zeromicro/go-zero/internal/devserver"
|
||||||
|
"github.com/zeromicro/go-zero/internal/profiling"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -38,6 +39,8 @@ type (
|
|||||||
Telemetry trace.Config `json:",optional"`
|
Telemetry trace.Config `json:",optional"`
|
||||||
DevServer DevServerConfig `json:",optional"`
|
DevServer DevServerConfig `json:",optional"`
|
||||||
Shutdown proc.ShutdownConf `json:",optional"`
|
Shutdown proc.ShutdownConf `json:",optional"`
|
||||||
|
// Profiling is the configuration for continuous profiling.
|
||||||
|
Profiling profiling.Config `json:",optional"`
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -70,7 +73,9 @@ func (sc ServiceConf) SetUp() error {
|
|||||||
if len(sc.MetricsUrl) > 0 {
|
if len(sc.MetricsUrl) > 0 {
|
||||||
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
|
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
|
||||||
}
|
}
|
||||||
|
|
||||||
devserver.StartAgent(sc.DevServer)
|
devserver.StartAgent(sc.DevServer)
|
||||||
|
profiling.Start(sc.Profiling)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/proc"
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
"github.com/zeromicro/go-zero/core/syncx"
|
|
||||||
"github.com/zeromicro/go-zero/core/threading"
|
"github.com/zeromicro/go-zero/core/threading"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,7 +36,7 @@ type (
|
|||||||
// NewServiceGroup returns a ServiceGroup.
|
// NewServiceGroup returns a ServiceGroup.
|
||||||
func NewServiceGroup() *ServiceGroup {
|
func NewServiceGroup() *ServiceGroup {
|
||||||
sg := new(ServiceGroup)
|
sg := new(ServiceGroup)
|
||||||
sg.stopOnce = syncx.Once(sg.doStop)
|
sg.stopOnce = sync.OnceFunc(sg.doStop)
|
||||||
return sg
|
return sg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
const (
|
const (
|
||||||
clusterNameKey = "CLUSTER_NAME"
|
clusterNameKey = "CLUSTER_NAME"
|
||||||
testEnv = "test.v"
|
testEnv = "test.v"
|
||||||
timeFormat = "2006-01-02 15:04:05"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -45,7 +44,7 @@ func Report(msg string) {
|
|||||||
if fn != nil {
|
if fn != nil {
|
||||||
reported := lessExecutor.DoOrDiscard(func() {
|
reported := lessExecutor.DoOrDiscard(func() {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
builder.WriteString(fmt.Sprintln(time.Now().Format(timeFormat)))
|
builder.WriteString(fmt.Sprintln(time.Now().Format(time.DateTime)))
|
||||||
if len(clusterName) > 0 {
|
if len(clusterName) > 0 {
|
||||||
builder.WriteString(fmt.Sprintf("cluster: %s\n", clusterName))
|
builder.WriteString(fmt.Sprintf("cluster: %s\n", clusterName))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -609,6 +609,28 @@ func (s *Redis) GetBitCtx(ctx context.Context, key string, offset int64) (int, e
|
|||||||
return int(v), nil
|
return int(v), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDel is the implementation of redis getdel command.
|
||||||
|
// Available since: redis version 6.2.0
|
||||||
|
func (s *Redis) GetDel(key string) (string, error) {
|
||||||
|
return s.GetDelCtx(context.Background(), key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDelCtx is the implementation of redis getdel command.
|
||||||
|
// Available since: redis version 6.2.0
|
||||||
|
func (s *Redis) GetDelCtx(ctx context.Context, key string) (string, error) {
|
||||||
|
conn, err := getRedis(s)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err := conn.GetDel(ctx, key).Result()
|
||||||
|
if errors.Is(err, red.Nil) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
// GetSet is the implementation of redis getset command.
|
// GetSet is the implementation of redis getset command.
|
||||||
func (s *Redis) GetSet(key, value string) (string, error) {
|
func (s *Redis) GetSet(key, value string) (string, error) {
|
||||||
return s.GetSetCtx(context.Background(), key, value)
|
return s.GetSetCtx(context.Background(), key, value)
|
||||||
|
|||||||
@@ -1071,6 +1071,34 @@ func TestRedis_Set(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRedis_GetDel(t *testing.T) {
|
||||||
|
t.Run("get_del", func(t *testing.T) {
|
||||||
|
runOnRedis(t, func(client *Redis) {
|
||||||
|
val, err := newRedis(client.Addr).GetDel("hello")
|
||||||
|
assert.Equal(t, "", val)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = client.Set("hello", "world")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
val, err = client.Get("hello")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "world", val)
|
||||||
|
val, err = client.GetDel("hello")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "world", val)
|
||||||
|
val, err = client.Get("hello")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "", val)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("get_del_with_error", func(t *testing.T) {
|
||||||
|
runOnRedisWithError(t, func(client *Redis) {
|
||||||
|
_, err := newRedis(client.Addr, badType()).GetDel("hello")
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestRedis_GetSet(t *testing.T) {
|
func TestRedis_GetSet(t *testing.T) {
|
||||||
t.Run("set_get", func(t *testing.T) {
|
t.Run("set_get", func(t *testing.T) {
|
||||||
runOnRedis(t, func(client *Redis) {
|
runOnRedis(t, func(client *Redis) {
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ func CreateBlockingNode(r *Redis) (ClosableNode, error) {
|
|||||||
case NodeType:
|
case NodeType:
|
||||||
client := red.NewClient(&red.Options{
|
client := red.NewClient(&red.Options{
|
||||||
Addr: r.Addr,
|
Addr: r.Addr,
|
||||||
|
Username: r.User,
|
||||||
Password: r.Pass,
|
Password: r.Pass,
|
||||||
DB: defaultDatabase,
|
DB: defaultDatabase,
|
||||||
MaxRetries: maxRetries,
|
MaxRetries: maxRetries,
|
||||||
@@ -32,6 +33,7 @@ func CreateBlockingNode(r *Redis) (ClosableNode, error) {
|
|||||||
case ClusterType:
|
case ClusterType:
|
||||||
client := red.NewClusterClient(&red.ClusterOptions{
|
client := red.NewClusterClient(&red.ClusterOptions{
|
||||||
Addrs: splitClusterAddrs(r.Addr),
|
Addrs: splitClusterAddrs(r.Addr),
|
||||||
|
Username: r.User,
|
||||||
Password: r.Pass,
|
Password: r.Pass,
|
||||||
MaxRetries: maxRetries,
|
MaxRetries: maxRetries,
|
||||||
PoolSize: 1,
|
PoolSize: 1,
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ func getClient(r *Redis) (*red.Client, error) {
|
|||||||
}
|
}
|
||||||
store := red.NewClient(&red.Options{
|
store := red.NewClient(&red.Options{
|
||||||
Addr: r.Addr,
|
Addr: r.Addr,
|
||||||
|
Username: r.User,
|
||||||
Password: r.Pass,
|
Password: r.Pass,
|
||||||
DB: defaultDatabase,
|
DB: defaultDatabase,
|
||||||
MaxRetries: maxRetries,
|
MaxRetries: maxRetries,
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ func getCluster(r *Redis) (*red.ClusterClient, error) {
|
|||||||
}
|
}
|
||||||
store := red.NewClusterClient(&red.ClusterOptions{
|
store := red.NewClusterClient(&red.ClusterOptions{
|
||||||
Addrs: splitClusterAddrs(r.Addr),
|
Addrs: splitClusterAddrs(r.Addr),
|
||||||
|
Username: r.User,
|
||||||
Password: r.Pass,
|
Password: r.Pass,
|
||||||
MaxRetries: maxRetries,
|
MaxRetries: maxRetries,
|
||||||
MinIdleConns: idleConns,
|
MinIdleConns: idleConns,
|
||||||
|
|||||||
29
core/stores/sqlx/config.go
Normal file
29
core/stores/sqlx/config.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package sqlx
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
errEmptyDatasource = errors.New("empty datasource")
|
||||||
|
errEmptyDriverName = errors.New("empty driver name")
|
||||||
|
)
|
||||||
|
|
||||||
|
// SqlConf defines the configuration for sqlx.
|
||||||
|
type SqlConf struct {
|
||||||
|
DataSource string
|
||||||
|
DriverName string `json:",default=mysql"`
|
||||||
|
Replicas []string `json:",optional"`
|
||||||
|
Policy string `json:",default=round-robin,options=round-robin|random"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the SqlxConf.
|
||||||
|
func (sc SqlConf) Validate() error {
|
||||||
|
if len(sc.DataSource) == 0 {
|
||||||
|
return errEmptyDatasource
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sc.DriverName) == 0 {
|
||||||
|
return errEmptyDriverName
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
29
core/stores/sqlx/config_test.go
Normal file
29
core/stores/sqlx/config_test.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package sqlx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidate(t *testing.T) {
|
||||||
|
text := []byte(`DataSource: primary:password@tcp(127.0.0.1:3306)/primary_db
|
||||||
|
`)
|
||||||
|
|
||||||
|
var sc SqlConf
|
||||||
|
err := conf.LoadFromYamlBytes(text, &sc)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "mysql", sc.DriverName)
|
||||||
|
assert.Equal(t, policyRoundRobin, sc.Policy)
|
||||||
|
assert.Nil(t, sc.Validate())
|
||||||
|
|
||||||
|
sc = SqlConf{}
|
||||||
|
assert.Equal(t, errEmptyDatasource, sc.Validate())
|
||||||
|
|
||||||
|
sc.DataSource = "primary:password@tcp(127.0.0.1:3306)/primary_db"
|
||||||
|
assert.Equal(t, errEmptyDriverName, sc.Validate())
|
||||||
|
|
||||||
|
sc.DriverName = "mysql"
|
||||||
|
assert.Nil(t, sc.Validate())
|
||||||
|
}
|
||||||
@@ -267,6 +267,20 @@ func TestUnmarshalRowStruct(t *testing.T) {
|
|||||||
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
Name string
|
||||||
|
age int
|
||||||
|
})
|
||||||
|
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, rows, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
||||||
|
})
|
||||||
|
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -310,6 +324,20 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
|
|||||||
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
|
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
age int `db:"age"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
})
|
||||||
|
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, rows, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
|
||||||
|
})
|
||||||
|
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
var value struct {
|
var value struct {
|
||||||
Age *int `db:"age"`
|
Age *int `db:"age"`
|
||||||
@@ -1307,6 +1335,7 @@ func TestAnonymousStructPr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAnonymousStructPrError(t *testing.T) {
|
func TestAnonymousStructPrError(t *testing.T) {
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
type Score struct {
|
type Score struct {
|
||||||
Discipline string `db:"discipline"`
|
Discipline string `db:"discipline"`
|
||||||
score uint `db:"score"`
|
score uint `db:"score"`
|
||||||
@@ -1319,13 +1348,12 @@ func TestAnonymousStructPrError(t *testing.T) {
|
|||||||
*ClassType
|
*ClassType
|
||||||
Score
|
Score
|
||||||
}
|
}
|
||||||
|
|
||||||
var value []*struct {
|
var value []*struct {
|
||||||
Age int64 `db:"age"`
|
Age int64 `db:"age"`
|
||||||
Class
|
Class
|
||||||
Name string `db:"name"`
|
Name string `db:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
|
||||||
rs := sqlmock.NewRows([]string{
|
rs := sqlmock.NewRows([]string{
|
||||||
"name",
|
"name",
|
||||||
"age",
|
"age",
|
||||||
@@ -1338,10 +1366,50 @@ func TestAnonymousStructPrError(t *testing.T) {
|
|||||||
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
|
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").
|
mock.ExpectQuery("select (.+) from users where user=?").
|
||||||
WithArgs("anyone").WillReturnRows(rs)
|
WithArgs("anyone").WillReturnRows(rs)
|
||||||
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
return unmarshalRows(&value, rows, true)
|
return unmarshalRows(&value, rows, true)
|
||||||
}, "select name, age, grade, discipline, class_name, score from users where user=?",
|
}, "select name, age, grade, discipline, class_name, score from users where user=?",
|
||||||
"anyone"))
|
"anyone"), ErrNotReadableValue)
|
||||||
|
if len(value) > 0 {
|
||||||
|
assert.Equal(t, value[0].score, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
type Score struct {
|
||||||
|
Discipline string
|
||||||
|
score uint
|
||||||
|
}
|
||||||
|
type ClassType struct {
|
||||||
|
Grade sql.NullString
|
||||||
|
ClassName *string
|
||||||
|
}
|
||||||
|
type Class struct {
|
||||||
|
*ClassType
|
||||||
|
Score
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []*struct {
|
||||||
|
Age int64
|
||||||
|
Class
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
rs := sqlmock.NewRows([]string{
|
||||||
|
"name",
|
||||||
|
"age",
|
||||||
|
"grade",
|
||||||
|
"discipline",
|
||||||
|
"class_name",
|
||||||
|
"score",
|
||||||
|
}).
|
||||||
|
AddRow("first", 2, nil, "math", "experimental class", 100).
|
||||||
|
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").
|
||||||
|
WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, age, grade, discipline, class_name, score from users where user=?",
|
||||||
|
"anyone"), ErrNotMatchDestination)
|
||||||
if len(value) > 0 {
|
if len(value) > 0 {
|
||||||
assert.Equal(t, value[0].score, 0)
|
assert.Equal(t, value[0].score, 0)
|
||||||
}
|
}
|
||||||
|
|||||||
65
core/stores/sqlx/rwstrategy.go
Normal file
65
core/stores/sqlx/rwstrategy.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package sqlx
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// policyRoundRobin round-robin policy for selecting replicas.
|
||||||
|
policyRoundRobin = "round-robin"
|
||||||
|
// policyRandom random policy for selecting replicas.
|
||||||
|
policyRandom = "random"
|
||||||
|
|
||||||
|
// readPrimaryMode indicates that the operation is a read,
|
||||||
|
// but should be performed on the primary database instance.
|
||||||
|
//
|
||||||
|
// This mode is used in scenarios where data freshness and consistency are critical,
|
||||||
|
// such as immediately after writes or where replication lag may cause stale reads.
|
||||||
|
readPrimaryMode readWriteMode = "read-primary"
|
||||||
|
|
||||||
|
// readReplicaMode indicates that the operation is a read from replicas.
|
||||||
|
// This is suitable for scenarios where eventual consistency is acceptable,
|
||||||
|
// and the goal is to offload traffic from the primary and improve read scalability.
|
||||||
|
readReplicaMode readWriteMode = "read-replica"
|
||||||
|
|
||||||
|
// writeMode indicates that the operation is a write operation (to primary).
|
||||||
|
writeMode readWriteMode = "write"
|
||||||
|
|
||||||
|
// notSpecifiedMode indicates that the read/write mode is not specified.
|
||||||
|
notSpecifiedMode readWriteMode = ""
|
||||||
|
)
|
||||||
|
|
||||||
|
type readWriteModeKey struct{}
|
||||||
|
|
||||||
|
// WithReadPrimary sets the context to read-primary mode.
|
||||||
|
func WithReadPrimary(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, readWriteModeKey{}, readPrimaryMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithReadReplica sets the context to read-replica mode.
|
||||||
|
func WithReadReplica(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, readWriteModeKey{}, readReplicaMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithWrite sets the context to write mode, indicating that the operation is a write operation.
|
||||||
|
func WithWrite(ctx context.Context) context.Context {
|
||||||
|
return context.WithValue(ctx, readWriteModeKey{}, writeMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
type readWriteMode string
|
||||||
|
|
||||||
|
func (m readWriteMode) isValid() bool {
|
||||||
|
return m == readPrimaryMode || m == readReplicaMode || m == writeMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func getReadWriteMode(ctx context.Context) readWriteMode {
|
||||||
|
if mode := ctx.Value(readWriteModeKey{}); mode != nil {
|
||||||
|
if v, ok := mode.(readWriteMode); ok && v.isValid() {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return notSpecifiedMode
|
||||||
|
}
|
||||||
|
|
||||||
|
func usePrimary(ctx context.Context) bool {
|
||||||
|
return getReadWriteMode(ctx) != readReplicaMode
|
||||||
|
}
|
||||||
142
core/stores/sqlx/rwstrategy_test.go
Normal file
142
core/stores/sqlx/rwstrategy_test.go
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
package sqlx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsValid(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
mode readWriteMode
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid read-primary mode",
|
||||||
|
mode: readPrimaryMode,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid read-replica mode",
|
||||||
|
mode: readReplicaMode,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid write mode",
|
||||||
|
mode: writeMode,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not specified mode (empty)",
|
||||||
|
mode: notSpecifiedMode,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid custom string",
|
||||||
|
mode: readWriteMode("delete"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case sensitive check",
|
||||||
|
mode: readWriteMode("READ"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
actual := tc.mode.isValid()
|
||||||
|
assert.Equal(t, tc.expected, actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithReadMode(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
readPrimaryCtx := WithReadPrimary(ctx)
|
||||||
|
|
||||||
|
val := readPrimaryCtx.Value(readWriteModeKey{})
|
||||||
|
assert.Equal(t, readPrimaryMode, val)
|
||||||
|
|
||||||
|
readReplicaCtx := WithReadReplica(ctx)
|
||||||
|
val = readReplicaCtx.Value(readWriteModeKey{})
|
||||||
|
assert.Equal(t, readReplicaMode, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithWriteMode(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
writeCtx := WithWrite(ctx)
|
||||||
|
|
||||||
|
val := writeCtx.Value(readWriteModeKey{})
|
||||||
|
assert.Equal(t, writeMode, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetReadWriteMode(t *testing.T) {
|
||||||
|
t.Run("valid read-primary mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
|
||||||
|
assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("valid read-replica mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
|
||||||
|
assert.Equal(t, readReplicaMode, getReadWriteMode(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("valid write mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
|
||||||
|
assert.Equal(t, writeMode, getReadWriteMode(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid mode value (wrong type)", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, "not-a-mode")
|
||||||
|
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid mode value (wrong value)", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("delete"))
|
||||||
|
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no mode set", func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsePrimary(t *testing.T) {
|
||||||
|
t.Run("context with read-replica mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
|
||||||
|
assert.False(t, usePrimary(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context with read-primary mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
|
||||||
|
assert.True(t, usePrimary(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context with write mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
|
||||||
|
assert.True(t, usePrimary(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context with invalid mode", func(t *testing.T) {
|
||||||
|
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("invalid"))
|
||||||
|
assert.True(t, usePrimary(ctx))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context with no mode set", func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
assert.True(t, usePrimary(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithModeTwice(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx = WithReadPrimary(ctx)
|
||||||
|
writeCtx := WithWrite(ctx)
|
||||||
|
|
||||||
|
val := writeCtx.Value(readWriteModeKey{})
|
||||||
|
assert.Equal(t, writeMode, val)
|
||||||
|
}
|
||||||
@@ -4,6 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/breaker"
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
"github.com/zeromicro/go-zero/core/errorx"
|
"github.com/zeromicro/go-zero/core/errorx"
|
||||||
@@ -52,9 +55,10 @@ type (
|
|||||||
beginTx beginnable
|
beginTx beginnable
|
||||||
brk breaker.Breaker
|
brk breaker.Breaker
|
||||||
accept breaker.Acceptable
|
accept breaker.Acceptable
|
||||||
|
index uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
connProvider func() (*sql.DB, error)
|
connProvider func(ctx context.Context) (*sql.DB, error)
|
||||||
|
|
||||||
sessionConn interface {
|
sessionConn interface {
|
||||||
Exec(query string, args ...any) (sql.Result, error)
|
Exec(query string, args ...any) (sql.Result, error)
|
||||||
@@ -64,10 +68,41 @@ type (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MustNewConn returns a SqlConn with the given SqlConf.
|
||||||
|
func MustNewConn(c SqlConf, opts ...SqlOption) SqlConn {
|
||||||
|
conn, err := NewConn(c, opts...)
|
||||||
|
if err != nil {
|
||||||
|
logx.Must(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConn returns a SqlConn with the given SqlConf.
|
||||||
|
func NewConn(c SqlConf, opts ...SqlOption) (SqlConn, error) {
|
||||||
|
if err := c.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &commonSqlConn{
|
||||||
|
onError: func(ctx context.Context, err error) {
|
||||||
|
logInstanceError(ctx, c.DataSource, err)
|
||||||
|
},
|
||||||
|
beginTx: begin,
|
||||||
|
brk: breaker.NewBreaker(),
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(conn)
|
||||||
|
}
|
||||||
|
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
// NewSqlConn returns a SqlConn with given driver name and datasource.
|
// NewSqlConn returns a SqlConn with given driver name and datasource.
|
||||||
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||||
conn := &commonSqlConn{
|
conn := &commonSqlConn{
|
||||||
connProv: func() (*sql.DB, error) {
|
connProv: func(context.Context) (*sql.DB, error) {
|
||||||
return getSqlConn(driverName, datasource)
|
return getSqlConn(driverName, datasource)
|
||||||
},
|
},
|
||||||
onError: func(ctx context.Context, err error) {
|
onError: func(ctx context.Context, err error) {
|
||||||
@@ -87,7 +122,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
|||||||
// Use it with caution; it's provided for other ORM to interact with.
|
// Use it with caution; it's provided for other ORM to interact with.
|
||||||
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
|
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
|
||||||
conn := &commonSqlConn{
|
conn := &commonSqlConn{
|
||||||
connProv: func() (*sql.DB, error) {
|
connProv: func(ctx context.Context) (*sql.DB, error) {
|
||||||
return db, nil
|
return db, nil
|
||||||
},
|
},
|
||||||
onError: func(ctx context.Context, err error) {
|
onError: func(ctx context.Context, err error) {
|
||||||
@@ -123,7 +158,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
|
|||||||
|
|
||||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||||
var conn *sql.DB
|
var conn *sql.DB
|
||||||
conn, err = db.connProv()
|
conn, err = db.connProv(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.onError(ctx, err)
|
db.onError(ctx, err)
|
||||||
return err
|
return err
|
||||||
@@ -151,7 +186,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
|
|||||||
|
|
||||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||||
var conn *sql.DB
|
var conn *sql.DB
|
||||||
conn, err = db.connProv()
|
conn, err = db.connProv(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.onError(ctx, err)
|
db.onError(ctx, err)
|
||||||
return err
|
return err
|
||||||
@@ -242,7 +277,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
|
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
|
||||||
return db.connProv()
|
return db.connProv(context.Background())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
||||||
@@ -288,7 +323,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
|||||||
q string, args ...any) (err error) {
|
q string, args ...any) (err error) {
|
||||||
var scanFailed bool
|
var scanFailed bool
|
||||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||||
conn, err := db.connProv()
|
conn, err := db.connProv(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.onError(ctx, err)
|
db.onError(ctx, err)
|
||||||
return err
|
return err
|
||||||
@@ -311,6 +346,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, replicas []string) connProvider {
|
||||||
|
return func(ctx context.Context) (*sql.DB, error) {
|
||||||
|
replicaCount := len(replicas)
|
||||||
|
|
||||||
|
if replicaCount == 0 || usePrimary(ctx) {
|
||||||
|
return getSqlConn(driverName, datasource)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dsn string
|
||||||
|
|
||||||
|
if replicaCount == 1 {
|
||||||
|
dsn = replicas[0]
|
||||||
|
} else {
|
||||||
|
if len(policy) == 0 {
|
||||||
|
policy = policyRoundRobin
|
||||||
|
}
|
||||||
|
|
||||||
|
switch policy {
|
||||||
|
case policyRandom:
|
||||||
|
dsn = replicas[rand.Intn(replicaCount)]
|
||||||
|
case policyRoundRobin:
|
||||||
|
index := atomic.AddUint32(&sc.index, 1) - 1
|
||||||
|
dsn = replicas[index%uint32(replicaCount)]
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown policy: %s", policy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return getSqlConn(driverName, dsn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithAcceptable returns a SqlOption that setting the acceptable function.
|
// WithAcceptable returns a SqlOption that setting the acceptable function.
|
||||||
// acceptable is the func to check if the error can be accepted.
|
// acceptable is the func to check if the error can be accepted.
|
||||||
func WithAcceptable(acceptable func(err error) bool) SqlOption {
|
func WithAcceptable(acceptable func(err error) bool) SqlOption {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package sqlx
|
package sqlx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
@@ -98,7 +99,7 @@ func TestSqlConn_RawDB(t *testing.T) {
|
|||||||
func TestSqlConn_Errors(t *testing.T) {
|
func TestSqlConn_Errors(t *testing.T) {
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
conn := NewSqlConnFromDB(db)
|
conn := NewSqlConnFromDB(db)
|
||||||
conn.(*commonSqlConn).connProv = func() (*sql.DB, error) {
|
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
|
||||||
return nil, errors.New("error")
|
return nil, errors.New("error")
|
||||||
}
|
}
|
||||||
_, err := conn.Prepare("any")
|
_, err := conn.Prepare("any")
|
||||||
@@ -138,6 +139,148 @@ func TestSqlConn_Errors(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfigSqlConn(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
assert.NotNil(t, db)
|
||||||
|
assert.NotNil(t, mock)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
connManager.Inject(mockedDatasource, db)
|
||||||
|
mock.ExpectExec("any")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
|
||||||
|
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf, withMysqlAcceptable())
|
||||||
|
|
||||||
|
_, err = conn.Exec("any", "value")
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
_, err = conn.Prepare("any")
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
|
var val string
|
||||||
|
assert.NotNil(t, conn.QueryRow(&val, "any"))
|
||||||
|
assert.NotNil(t, conn.QueryRowPartial(&val, "any"))
|
||||||
|
assert.NotNil(t, conn.QueryRows(&val, "any"))
|
||||||
|
assert.NotNil(t, conn.QueryRowsPartial(&val, "any"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigSqlConnStatement(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
assert.NotNil(t, db)
|
||||||
|
assert.NotNil(t, mock)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
connManager.Inject(mockedDatasource, db)
|
||||||
|
|
||||||
|
mock.ExpectPrepare("any")
|
||||||
|
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||||
|
mock.ExpectPrepare("any")
|
||||||
|
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(row)
|
||||||
|
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf, withMysqlAcceptable())
|
||||||
|
stmt, err := conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
res, err := stmt.Exec()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
lastInsertID, err := res.LastInsertId()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(2), lastInsertID)
|
||||||
|
rowsAffected, err := res.RowsAffected()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(3), rowsAffected)
|
||||||
|
|
||||||
|
stmt, err = conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var val string
|
||||||
|
err = stmt.QueryRow(&val)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "bar", val)
|
||||||
|
|
||||||
|
mock.ExpectPrepare("any")
|
||||||
|
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
|
||||||
|
stmt, err = conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var vals []string
|
||||||
|
assert.NoError(t, stmt.QueryRowsPartial(&vals))
|
||||||
|
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigSqlConnQuery(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
assert.NotNil(t, db)
|
||||||
|
assert.NotNil(t, mock)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
connManager.Inject(mockedDatasource, db)
|
||||||
|
|
||||||
|
t.Run("QueryRow", func(t *testing.T) {
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf)
|
||||||
|
var val string
|
||||||
|
assert.NoError(t, conn.QueryRow(&val, "any"))
|
||||||
|
assert.Equal(t, "bar", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("QueryRowPartial", func(t *testing.T) {
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf)
|
||||||
|
var val string
|
||||||
|
assert.NoError(t, conn.QueryRowPartial(&val, "any"))
|
||||||
|
assert.Equal(t, "bar", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("QueryRows", func(t *testing.T) {
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf)
|
||||||
|
var vals []string
|
||||||
|
assert.NoError(t, conn.QueryRows(&vals, "any"))
|
||||||
|
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("QueryRowsPartial", func(t *testing.T) {
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf)
|
||||||
|
var vals []string
|
||||||
|
assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
|
||||||
|
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigSqlConnErr(t *testing.T) {
|
||||||
|
t.Run("panic on empty config", func(t *testing.T) {
|
||||||
|
original := logx.ExitOnFatal.True()
|
||||||
|
logx.ExitOnFatal.Set(false)
|
||||||
|
defer logx.ExitOnFatal.Set(original)
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
MustNewConn(SqlConf{})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("on error", func(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
assert.NotNil(t, db)
|
||||||
|
assert.NotNil(t, mock)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
connManager.Inject(mockedDatasource, db)
|
||||||
|
|
||||||
|
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||||
|
conn := MustNewConn(conf)
|
||||||
|
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
|
||||||
|
return nil, errors.New("error")
|
||||||
|
}
|
||||||
|
_, err = conn.Prepare("any")
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestStatement(t *testing.T) {
|
func TestStatement(t *testing.T) {
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
mock.ExpectPrepare("any").WillBeClosed()
|
mock.ExpectPrepare("any").WillBeClosed()
|
||||||
@@ -303,6 +446,93 @@ func TestWithAcceptable(t *testing.T) {
|
|||||||
assert.True(t, conn.accept(acceptableErr3))
|
assert.True(t, conn.accept(acceptableErr3))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProvider(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
_ = connManager.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
primaryDSN := "primary:password@tcp(127.0.0.1:3306)/primary_db"
|
||||||
|
replicasDSN := []string{
|
||||||
|
"replica_one:pwd@tcp(localhost:3306)/replica_one",
|
||||||
|
"replica_two:pwd@tcp(localhost:3306)/replica_two",
|
||||||
|
"replica_three:pwd@tcp(localhost:3306)/replica_three",
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryDB, err := connManager.GetResource(primaryDSN, func() (io.Closer, error) { return sql.Open(mysqlDriverName, primaryDSN) })
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, primaryDB)
|
||||||
|
replicaOneDB, err := connManager.GetResource(replicasDSN[0], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[0]) })
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, replicaOneDB)
|
||||||
|
replicaTwoDB, err := connManager.GetResource(replicasDSN[1], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[1]) })
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, replicaTwoDB)
|
||||||
|
replicaThreeDB, err := connManager.GetResource(replicasDSN[2], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[2]) })
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.NotNil(t, replicaThreeDB)
|
||||||
|
|
||||||
|
sc := &commonSqlConn{}
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, nil)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
db, err := sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, primaryDB, db)
|
||||||
|
|
||||||
|
ctx = WithWrite(ctx)
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, primaryDB, db)
|
||||||
|
|
||||||
|
ctx = WithReadPrimary(ctx)
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, primaryDB, db)
|
||||||
|
|
||||||
|
// no mode set, should return primary
|
||||||
|
ctx = context.Background()
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, primaryDB, db)
|
||||||
|
|
||||||
|
ctx = WithReadReplica(ctx)
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]})
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, replicaOneDB, db)
|
||||||
|
|
||||||
|
// default policy is round-robin
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
|
||||||
|
replicas := []io.Closer{replicaOneDB, replicaTwoDB, replicaThreeDB}
|
||||||
|
for i := 0; i < len(replicasDSN); i++ {
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, replicas[i], db)
|
||||||
|
}
|
||||||
|
|
||||||
|
// random policy
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRandom, replicasDSN)
|
||||||
|
for i := 0; i < len(replicasDSN); i++ {
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Contains(t, replicas, db)
|
||||||
|
}
|
||||||
|
|
||||||
|
// unknown policy
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "unknown", replicasDSN)
|
||||||
|
_, err = sc.connProv(ctx)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
|
// empty policy transforms to round-robin
|
||||||
|
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "", replicasDSN)
|
||||||
|
for i := 0; i < len(replicasDSN); i++ {
|
||||||
|
db, err = sc.connProv(ctx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, replicas[i], db)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
||||||
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
||||||
var db *sql.DB
|
var db *sql.DB
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if driverName != mysqlDriverName {
|
if driverName == mysqlDriverName {
|
||||||
if cfg, e := mysql.ParseDSN(server); e != nil {
|
if cfg, e := mysql.ParseDSN(server); e != nil {
|
||||||
// if cannot parse, don't collect the metrics
|
// if cannot parse, don't collect the metrics
|
||||||
logx.Error(e)
|
logx.Error(e)
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ func begin(db *sql.DB) (trans, error) {
|
|||||||
|
|
||||||
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
|
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
|
||||||
fn func(context.Context, Session) error) (err error) {
|
fn func(context.Context, Session) error) (err error) {
|
||||||
conn, err := db.connProv()
|
conn, err := db.connProv(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.onError(ctx, err)
|
db.onError(ctx, err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ func TestTxExceptions(t *testing.T) {
|
|||||||
|
|
||||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
conn := &commonSqlConn{
|
conn := &commonSqlConn{
|
||||||
connProv: func() (*sql.DB, error) {
|
connProv: func(ctx context.Context) (*sql.DB, error) {
|
||||||
return nil, errors.New("foo")
|
return nil, errors.New("foo")
|
||||||
},
|
},
|
||||||
beginTx: begin,
|
beginTx: begin,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package stringx
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"slices"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/lang"
|
"github.com/zeromicro/go-zero/core/lang"
|
||||||
@@ -15,14 +16,9 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Contains checks if str is in list.
|
// Contains checks if str is in list.
|
||||||
|
// Deprecated: use slices.Contains instead.
|
||||||
func Contains(list []string, str string) bool {
|
func Contains(list []string, str string) bool {
|
||||||
for _, each := range list {
|
return slices.Contains(list, str)
|
||||||
if each == str {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter filters chars from s with given filter function.
|
// Filter filters chars from s with given filter function.
|
||||||
@@ -123,11 +119,7 @@ func Remove(strings []string, strs ...string) []string {
|
|||||||
// Reverse reverses s.
|
// Reverse reverses s.
|
||||||
func Reverse(s string) string {
|
func Reverse(s string) string {
|
||||||
runes := []rune(s)
|
runes := []rune(s)
|
||||||
|
slices.Reverse(runes)
|
||||||
for from, to := 0, len(runes)-1; from < to; from, to = from+1, to-1 {
|
|
||||||
runes[from], runes[to] = runes[to], runes[from]
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(runes)
|
return string(runes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,28 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestContainsString(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
slice []string
|
||||||
|
value string
|
||||||
|
expect bool
|
||||||
|
}{
|
||||||
|
{[]string{"1"}, "1", true},
|
||||||
|
{[]string{"1"}, "2", false},
|
||||||
|
{[]string{"1", "2"}, "1", true},
|
||||||
|
{[]string{"1", "2"}, "3", false},
|
||||||
|
{nil, "3", false},
|
||||||
|
{nil, "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, each := range cases {
|
||||||
|
t.Run(path.Join(each.slice...), func(t *testing.T) {
|
||||||
|
actual := Contains(each.slice, each.value)
|
||||||
|
assert.Equal(t, each.expect, actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNotEmpty(t *testing.T) {
|
func TestNotEmpty(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
args []string
|
args []string
|
||||||
@@ -41,28 +63,6 @@ func TestNotEmpty(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContainsString(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
slice []string
|
|
||||||
value string
|
|
||||||
expect bool
|
|
||||||
}{
|
|
||||||
{[]string{"1"}, "1", true},
|
|
||||||
{[]string{"1"}, "2", false},
|
|
||||||
{[]string{"1", "2"}, "1", true},
|
|
||||||
{[]string{"1", "2"}, "3", false},
|
|
||||||
{nil, "3", false},
|
|
||||||
{nil, "", false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, each := range cases {
|
|
||||||
t.Run(path.Join(each.slice...), func(t *testing.T) {
|
|
||||||
actual := Contains(each.slice, each.value)
|
|
||||||
assert.Equal(t, each.expect, actual)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFilter(t *testing.T) {
|
func TestFilter(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
input string
|
input string
|
||||||
|
|||||||
@@ -3,9 +3,7 @@ package syncx
|
|||||||
import "sync"
|
import "sync"
|
||||||
|
|
||||||
// Once returns a func that guarantees fn can only called once.
|
// Once returns a func that guarantees fn can only called once.
|
||||||
|
// Deprecated: use sync.OnceFunc instead.
|
||||||
func Once(fn func()) func() {
|
func Once(fn func()) func() {
|
||||||
once := new(sync.Once)
|
return sync.OnceFunc(fn)
|
||||||
return func() {
|
|
||||||
once.Do(fn)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const factor = 10
|
const factor = 10
|
||||||
@@ -100,6 +101,6 @@ func (r *StableRunner[I, O]) Wait() {
|
|||||||
close(r.done)
|
close(r.done)
|
||||||
r.runner.Wait()
|
r.runner.Wait()
|
||||||
for atomic.LoadUint64(&r.consumedIndex) < atomic.LoadUint64(&r.writtenIndex) {
|
for atomic.LoadUint64(&r.consumedIndex) < atomic.LoadUint64(&r.writtenIndex) {
|
||||||
runtime.Gosched()
|
time.Sleep(time.Millisecond)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/mathx"
|
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
"github.com/zeromicro/go-zero/core/stringx"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ func compare(v1, v2 string) int {
|
|||||||
fields1, fields2 := strings.Split(v1, "."), strings.Split(v2, ".")
|
fields1, fields2 := strings.Split(v1, "."), strings.Split(v2, ".")
|
||||||
ver1, ver2 := strsToInts(fields1), strsToInts(fields2)
|
ver1, ver2 := strsToInts(fields1), strsToInts(fields2)
|
||||||
ver1len, ver2len := len(ver1), len(ver2)
|
ver1len, ver2len := len(ver1), len(ver2)
|
||||||
shorter := mathx.MinInt(ver1len, ver2len)
|
shorter := min(ver1len, ver2len)
|
||||||
|
|
||||||
for i := 0; i < shorter; i++ {
|
for i := 0; i < shorter; i++ {
|
||||||
if ver1[i] == ver2[i] {
|
if ver1[i] == ver2[i] {
|
||||||
@@ -50,14 +50,7 @@ func compare(v1, v2 string) int {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return cmp.Compare(ver1len, ver2len)
|
||||||
if ver1len < ver2len {
|
|
||||||
return -1
|
|
||||||
} else if ver1len == ver2len {
|
|
||||||
return 0
|
|
||||||
} else {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func strsToInts(strs []string) []int64 {
|
func strsToInts(strs []string) []int64 {
|
||||||
|
|||||||
@@ -185,6 +185,8 @@ func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set the timeout if it's configured, take effect only if it's greater than 0
|
||||||
|
// and less than the deadline of the original request
|
||||||
if target.Timeout > 0 {
|
if target.Timeout > 0 {
|
||||||
timeout := time.Duration(target.Timeout) * time.Millisecond
|
timeout := time.Duration(target.Timeout) * time.Millisecond
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||||
@@ -276,7 +278,7 @@ func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.R
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &http.Request{
|
newReq := &http.Request{
|
||||||
Method: r.Method,
|
Method: r.Method,
|
||||||
URL: &u,
|
URL: &u,
|
||||||
Header: r.Header.Clone(),
|
Header: r.Header.Clone(),
|
||||||
@@ -285,7 +287,10 @@ func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.R
|
|||||||
ProtoMinor: r.ProtoMinor,
|
ProtoMinor: r.ProtoMinor,
|
||||||
ContentLength: r.ContentLength,
|
ContentLength: r.ContentLength,
|
||||||
Body: io.NopCloser(r.Body),
|
Body: io.NopCloser(r.Body),
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
// make sure the context is passed to the new request
|
||||||
|
return newReq.WithContext(r.Context()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.DescriptorSource, error) {
|
func createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.DescriptorSource, error) {
|
||||||
|
|||||||
@@ -201,6 +201,13 @@ func TestHttpToHttp(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("method not allowed", func(t *testing.T) {
|
||||||
|
resp, err := httpc.Do(context.Background(), http.MethodPost,
|
||||||
|
"http://localhost:18882/api/ping", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHttpToHttpBadUpstream(t *testing.T) {
|
func TestHttpToHttpBadUpstream(t *testing.T) {
|
||||||
|
|||||||
42
go.mod
42
go.mod
@@ -4,25 +4,25 @@ go 1.21
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
github.com/alicebob/miniredis/v2 v2.34.0
|
github.com/alicebob/miniredis/v2 v2.35.0
|
||||||
github.com/fatih/color v1.18.0
|
github.com/fatih/color v1.18.0
|
||||||
github.com/fullstorydev/grpcurl v1.9.2
|
github.com/fullstorydev/grpcurl v1.9.3
|
||||||
github.com/go-sql-driver/mysql v1.8.1
|
github.com/go-sql-driver/mysql v1.9.0
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.1
|
github.com/golang-jwt/jwt/v4 v4.5.2
|
||||||
github.com/golang/mock v1.6.0
|
github.com/golang/mock v1.6.0
|
||||||
github.com/golang/protobuf v1.5.4
|
github.com/golang/protobuf v1.5.4
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/jackc/pgx/v5 v5.7.2
|
github.com/grafana/pyroscope-go v1.2.2
|
||||||
|
github.com/jackc/pgx/v5 v5.7.4
|
||||||
github.com/jhump/protoreflect v1.17.0
|
github.com/jhump/protoreflect v1.17.0
|
||||||
github.com/olekukonko/tablewriter v0.0.5
|
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2
|
github.com/pelletier/go-toml/v2 v2.2.2
|
||||||
github.com/prometheus/client_golang v1.20.5
|
github.com/prometheus/client_golang v1.21.1
|
||||||
github.com/redis/go-redis/v9 v9.7.0
|
github.com/redis/go-redis/v9 v9.11.0
|
||||||
github.com/spaolacci/murmur3 v1.1.0
|
github.com/spaolacci/murmur3 v1.1.0
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
go.etcd.io/etcd/api/v3 v3.5.15
|
go.etcd.io/etcd/api/v3 v3.5.15
|
||||||
go.etcd.io/etcd/client/v3 v3.5.15
|
go.etcd.io/etcd/client/v3 v3.5.15
|
||||||
go.mongodb.org/mongo-driver v1.17.2
|
go.mongodb.org/mongo-driver v1.17.4
|
||||||
go.opentelemetry.io/otel v1.24.0
|
go.opentelemetry.io/otel v1.24.0
|
||||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
|
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
|
||||||
@@ -33,12 +33,12 @@ require (
|
|||||||
go.opentelemetry.io/otel/trace v1.24.0
|
go.opentelemetry.io/otel/trace v1.24.0
|
||||||
go.uber.org/automaxprocs v1.6.0
|
go.uber.org/automaxprocs v1.6.0
|
||||||
go.uber.org/goleak v1.3.0
|
go.uber.org/goleak v1.3.0
|
||||||
golang.org/x/net v0.34.0
|
golang.org/x/net v0.35.0
|
||||||
golang.org/x/sys v0.29.0
|
golang.org/x/sys v0.30.0
|
||||||
golang.org/x/time v0.9.0
|
golang.org/x/time v0.10.0
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20240711142825-46eb208f015d
|
google.golang.org/genproto/googleapis/api v0.0.0-20240711142825-46eb208f015d
|
||||||
google.golang.org/grpc v1.65.0
|
google.golang.org/grpc v1.65.0
|
||||||
google.golang.org/protobuf v1.36.4
|
google.golang.org/protobuf v1.36.5
|
||||||
gopkg.in/cheggaaa/pb.v1 v1.0.28
|
gopkg.in/cheggaaa/pb.v1 v1.0.28
|
||||||
gopkg.in/h2non/gock.v1 v1.1.2
|
gopkg.in/h2non/gock.v1 v1.1.2
|
||||||
gopkg.in/yaml.v2 v2.4.0
|
gopkg.in/yaml.v2 v2.4.0
|
||||||
@@ -50,7 +50,6 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
filippo.io/edwards25519 v1.1.0 // indirect
|
filippo.io/edwards25519 v1.1.0 // indirect
|
||||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
|
|
||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
github.com/bufbuild/protocompile v0.14.1 // indirect
|
github.com/bufbuild/protocompile v0.14.1 // indirect
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||||
@@ -73,6 +72,7 @@ require (
|
|||||||
github.com/google/gnostic-models v0.6.8 // indirect
|
github.com/google/gnostic-models v0.6.8 // indirect
|
||||||
github.com/google/go-cmp v0.6.0 // indirect
|
github.com/google/go-cmp v0.6.0 // indirect
|
||||||
github.com/google/gofuzz v1.2.0 // indirect
|
github.com/google/gofuzz v1.2.0 // indirect
|
||||||
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
||||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
@@ -80,7 +80,7 @@ require (
|
|||||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
github.com/josharian/intern v1.0.0 // indirect
|
github.com/josharian/intern v1.0.0 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/compress v1.17.9 // indirect
|
github.com/klauspost/compress v1.17.11 // indirect
|
||||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||||
github.com/mailru/easyjson v0.7.7 // indirect
|
github.com/mailru/easyjson v0.7.7 // indirect
|
||||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||||
@@ -93,7 +93,7 @@ require (
|
|||||||
github.com/openzipkin/zipkin-go v0.4.3 // indirect
|
github.com/openzipkin/zipkin-go v0.4.3 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/prometheus/client_model v0.6.1 // indirect
|
github.com/prometheus/client_model v0.6.1 // indirect
|
||||||
github.com/prometheus/common v0.55.0 // indirect
|
github.com/prometheus/common v0.62.0 // indirect
|
||||||
github.com/prometheus/procfs v0.15.1 // indirect
|
github.com/prometheus/procfs v0.15.1 // indirect
|
||||||
github.com/rivo/uniseg v0.2.0 // indirect
|
github.com/rivo/uniseg v0.2.0 // indirect
|
||||||
github.com/stretchr/objx v0.5.2 // indirect
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
@@ -109,11 +109,11 @@ require (
|
|||||||
go.uber.org/atomic v1.10.0 // indirect
|
go.uber.org/atomic v1.10.0 // indirect
|
||||||
go.uber.org/multierr v1.9.0 // indirect
|
go.uber.org/multierr v1.9.0 // indirect
|
||||||
go.uber.org/zap v1.24.0 // indirect
|
go.uber.org/zap v1.24.0 // indirect
|
||||||
golang.org/x/crypto v0.32.0 // indirect
|
golang.org/x/crypto v0.33.0 // indirect
|
||||||
golang.org/x/oauth2 v0.21.0 // indirect
|
golang.org/x/oauth2 v0.24.0 // indirect
|
||||||
golang.org/x/sync v0.10.0 // indirect
|
golang.org/x/sync v0.11.0 // indirect
|
||||||
golang.org/x/term v0.28.0 // indirect
|
golang.org/x/term v0.29.0 // indirect
|
||||||
golang.org/x/text v0.21.0 // indirect
|
golang.org/x/text v0.22.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect
|
||||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
|||||||
85
go.sum
85
go.sum
@@ -2,10 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
|||||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
|
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
|
||||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
|
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||||
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
|
|
||||||
github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8=
|
|
||||||
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
|
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
|
||||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
@@ -40,8 +38,8 @@ github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU
|
|||||||
github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew=
|
github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew=
|
||||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||||
github.com/fullstorydev/grpcurl v1.9.2 h1:ObqVQTZW7aFnhuqQoppUrvep2duMBanB0UYK2Mm8euo=
|
github.com/fullstorydev/grpcurl v1.9.3 h1:PC1Xi3w+JAvEE2Tg2Gf2RfVgPbf9+tbuQr1ZkyVU3jk=
|
||||||
github.com/fullstorydev/grpcurl v1.9.2/go.mod h1:jLfcF55HAz6TYIJY9xFFWgsl0D7o2HlxA5Z4lUG0Tdo=
|
github.com/fullstorydev/grpcurl v1.9.3/go.mod h1:/b4Wxe8bG6ndAjlfSUjwseQReUDUvBJiFEB7UllOlUE=
|
||||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||||
@@ -55,15 +53,15 @@ github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En
|
|||||||
github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
|
github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
|
||||||
github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU=
|
github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU=
|
||||||
github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
|
github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
|
||||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo=
|
||||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw=
|
||||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo=
|
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
@@ -82,6 +80,10 @@ github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJY
|
|||||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
|
||||||
|
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
|
||||||
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
|
||||||
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
|
||||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
||||||
@@ -90,8 +92,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
|||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI=
|
github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
|
||||||
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5XumQh94=
|
github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5XumQh94=
|
||||||
@@ -103,8 +105,8 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm
|
|||||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
|
||||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
|
||||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
@@ -121,7 +123,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
|||||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
|
||||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
@@ -135,8 +136,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
|||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||||
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4=
|
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4=
|
||||||
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
|
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
|
||||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
|
||||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
|
||||||
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
|
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
|
||||||
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
|
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
|
||||||
github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg=
|
github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg=
|
||||||
@@ -151,16 +150,16 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
||||||
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
||||||
github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=
|
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
|
||||||
github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
|
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
|
||||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||||
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
|
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
||||||
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
|
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||||
github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E=
|
github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs=
|
||||||
github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw=
|
github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
@@ -203,8 +202,8 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5
|
|||||||
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
|
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
|
||||||
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
|
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
|
||||||
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
|
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
|
||||||
go.mongodb.org/mongo-driver v1.17.2 h1:gvZyk8352qSfzyZ2UMWcpDpMSGEr1eqE4T793SqyhzM=
|
go.mongodb.org/mongo-driver v1.17.4 h1:jUorfmVzljjr0FLzYQsGP8cgN/qzzxlY9Vh0C9KFXVw=
|
||||||
go.mongodb.org/mongo-driver v1.17.2/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
go.mongodb.org/mongo-driver v1.17.4/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||||
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
||||||
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
||||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=
|
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=
|
||||||
@@ -241,8 +240,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
|||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
|
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||||
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
@@ -254,17 +253,17 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
|
|||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||||
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||||
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||||
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
|
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
|
||||||
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
@@ -276,20 +275,20 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
|
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
|
||||||
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
|
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||||
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
|
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
|
||||||
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||||
@@ -308,8 +307,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 h1:
|
|||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
|
||||||
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
|
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
|
||||||
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
|
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
|
||||||
google.golang.org/protobuf v1.36.4 h1:6A3ZDJHn/eNqc1i+IdefRzy/9PokBTPvcqMySR7NNIM=
|
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
|
||||||
google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
|||||||
263
internal/profiling/profiling.go
Normal file
263
internal/profiling/profiling.go
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
package profiling
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/grafana/pyroscope-go"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
|
"github.com/zeromicro/go-zero/core/threading"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultCheckInterval = time.Second * 10
|
||||||
|
defaultProfilingDuration = time.Minute * 2
|
||||||
|
defaultUploadRate = time.Second * 15
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
Config struct {
|
||||||
|
// Name is the name of the application.
|
||||||
|
Name string `json:",optional,inherit"`
|
||||||
|
// ServerAddr is the address of the profiling server.
|
||||||
|
ServerAddr string
|
||||||
|
// AuthUser is the username for basic authentication.
|
||||||
|
AuthUser string `json:",optional"`
|
||||||
|
// AuthPassword is the password for basic authentication.
|
||||||
|
AuthPassword string `json:",optional"`
|
||||||
|
// UploadRate is the duration for which profiling data is uploaded.
|
||||||
|
UploadRate time.Duration `json:",default=15s"`
|
||||||
|
// CheckInterval is the interval to check if profiling should start.
|
||||||
|
CheckInterval time.Duration `json:",default=10s"`
|
||||||
|
// ProfilingDuration is the duration for which profiling data is collected.
|
||||||
|
ProfilingDuration time.Duration `json:",default=2m"`
|
||||||
|
// CpuThreshold the collection is allowed only when the current service cpu > CpuThreshold
|
||||||
|
CpuThreshold int64 `json:",default=700,range=[0:1000)"`
|
||||||
|
|
||||||
|
// ProfileType is the type of profiling to be performed.
|
||||||
|
ProfileType ProfileType
|
||||||
|
}
|
||||||
|
|
||||||
|
ProfileType struct {
|
||||||
|
// Logger is a flag to enable or disable logging.
|
||||||
|
Logger bool `json:",default=false"`
|
||||||
|
// CPU is a flag to disable CPU profiling.
|
||||||
|
CPU bool `json:",default=true"`
|
||||||
|
// Goroutines is a flag to disable goroutine profiling.
|
||||||
|
Goroutines bool `json:",default=true"`
|
||||||
|
// Memory is a flag to disable memory profiling.
|
||||||
|
Memory bool `json:",default=true"`
|
||||||
|
// Mutex is a flag to disable mutex profiling.
|
||||||
|
Mutex bool `json:",default=false"`
|
||||||
|
// Block is a flag to disable block profiling.
|
||||||
|
Block bool `json:",default=false"`
|
||||||
|
}
|
||||||
|
|
||||||
|
profiler interface {
|
||||||
|
Start() error
|
||||||
|
Stop() error
|
||||||
|
}
|
||||||
|
|
||||||
|
pyroscopeProfiler struct {
|
||||||
|
c Config
|
||||||
|
profiler *pyroscope.Profiler
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
once sync.Once
|
||||||
|
|
||||||
|
newProfiler = func(c Config) profiler {
|
||||||
|
return newPyroscopeProfiler(c)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Start initializes the pyroscope profiler with the given configuration.
|
||||||
|
func Start(c Config) {
|
||||||
|
// check if the profiling is enabled
|
||||||
|
if len(c.ServerAddr) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// set default values for the configuration
|
||||||
|
if c.ProfilingDuration <= 0 {
|
||||||
|
c.ProfilingDuration = defaultProfilingDuration
|
||||||
|
}
|
||||||
|
|
||||||
|
// set default values for the configuration
|
||||||
|
if c.CheckInterval <= 0 {
|
||||||
|
c.CheckInterval = defaultCheckInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.UploadRate <= 0 {
|
||||||
|
c.UploadRate = defaultUploadRate
|
||||||
|
}
|
||||||
|
|
||||||
|
once.Do(func() {
|
||||||
|
logx.Info("continuous profiling started")
|
||||||
|
|
||||||
|
threading.GoSafe(func() {
|
||||||
|
startPyroscope(c, proc.Done())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// startPyroscope starts the pyroscope profiler with the given configuration.
|
||||||
|
func startPyroscope(c Config, done <-chan struct{}) {
|
||||||
|
var (
|
||||||
|
pr profiler
|
||||||
|
err error
|
||||||
|
latestProfilingTime time.Time
|
||||||
|
intervalTicker = time.NewTicker(c.CheckInterval)
|
||||||
|
profilingTicker = time.NewTicker(c.ProfilingDuration)
|
||||||
|
)
|
||||||
|
|
||||||
|
defer profilingTicker.Stop()
|
||||||
|
defer intervalTicker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-intervalTicker.C:
|
||||||
|
// Check if the machine is overloaded and if the profiler is not running
|
||||||
|
if pr == nil && isCpuOverloaded(c) {
|
||||||
|
pr = newProfiler(c)
|
||||||
|
if err := pr.Start(); err != nil {
|
||||||
|
logx.Errorf("failed to start profiler: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// record the latest profiling time
|
||||||
|
latestProfilingTime = time.Now()
|
||||||
|
logx.Infof("pyroscope profiler started.")
|
||||||
|
}
|
||||||
|
case <-profilingTicker.C:
|
||||||
|
// check if the profiling duration has passed
|
||||||
|
if !time.Now().After(latestProfilingTime.Add(c.ProfilingDuration)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if the profiler is already running, if so, skip
|
||||||
|
if pr != nil {
|
||||||
|
if err = pr.Stop(); err != nil {
|
||||||
|
logx.Errorf("failed to stop profiler: %v", err)
|
||||||
|
}
|
||||||
|
logx.Infof("pyroscope profiler stopped.")
|
||||||
|
pr = nil
|
||||||
|
}
|
||||||
|
case <-done:
|
||||||
|
logx.Infof("continuous profiling stopped.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// genPyroscopeConf generates the pyroscope configuration based on the given config.
|
||||||
|
func genPyroscopeConf(c Config) pyroscope.Config {
|
||||||
|
pConf := pyroscope.Config{
|
||||||
|
UploadRate: c.UploadRate,
|
||||||
|
ApplicationName: c.Name,
|
||||||
|
BasicAuthUser: c.AuthUser, // http basic auth user
|
||||||
|
BasicAuthPassword: c.AuthPassword, // http basic auth password
|
||||||
|
ServerAddress: c.ServerAddr,
|
||||||
|
Logger: nil,
|
||||||
|
HTTPHeaders: map[string]string{},
|
||||||
|
// you can provide static tags via a map:
|
||||||
|
Tags: map[string]string{
|
||||||
|
"name": c.Name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ProfileType.Logger {
|
||||||
|
pConf.Logger = logx.WithCallerSkip(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.ProfileType.CPU {
|
||||||
|
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileCPU)
|
||||||
|
}
|
||||||
|
if c.ProfileType.Goroutines {
|
||||||
|
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileGoroutines)
|
||||||
|
}
|
||||||
|
if c.ProfileType.Memory {
|
||||||
|
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileAllocObjects, pyroscope.ProfileAllocSpace,
|
||||||
|
pyroscope.ProfileInuseObjects, pyroscope.ProfileInuseSpace)
|
||||||
|
}
|
||||||
|
if c.ProfileType.Mutex {
|
||||||
|
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileMutexCount, pyroscope.ProfileMutexDuration)
|
||||||
|
}
|
||||||
|
if c.ProfileType.Block {
|
||||||
|
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileBlockCount, pyroscope.ProfileBlockDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.Infof("applicationName: %s", pConf.ApplicationName)
|
||||||
|
|
||||||
|
return pConf
|
||||||
|
}
|
||||||
|
|
||||||
|
// isCpuOverloaded checks the machine performance based on the given configuration.
|
||||||
|
func isCpuOverloaded(c Config) bool {
|
||||||
|
currentValue := stat.CpuUsage()
|
||||||
|
if currentValue >= c.CpuThreshold {
|
||||||
|
logx.Infof("continuous profiling cpu overload, cpu: %d", currentValue)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPyroscopeProfiler(c Config) profiler {
|
||||||
|
return &pyroscopeProfiler{
|
||||||
|
c: c,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pyroscopeProfiler) Start() error {
|
||||||
|
pConf := genPyroscopeConf(p.c)
|
||||||
|
// set mutex and block profile rate
|
||||||
|
setFraction(p.c)
|
||||||
|
prof, err := pyroscope.Start(pConf)
|
||||||
|
if err != nil {
|
||||||
|
resetFraction(p.c)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.profiler = prof
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pyroscopeProfiler) Stop() error {
|
||||||
|
if p.profiler == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.profiler.Stop(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resetFraction(p.c)
|
||||||
|
p.profiler = nil
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setFraction(c Config) {
|
||||||
|
// These 2 lines are only required if you're using mutex or block profiling
|
||||||
|
if c.ProfileType.Mutex {
|
||||||
|
runtime.SetMutexProfileFraction(10) // 10/seconds
|
||||||
|
}
|
||||||
|
if c.ProfileType.Block {
|
||||||
|
runtime.SetBlockProfileRate(1000 * 1000) // 1/millisecond
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resetFraction(c Config) {
|
||||||
|
// These 2 lines are only required if you're using mutex or block profiling
|
||||||
|
if c.ProfileType.Mutex {
|
||||||
|
runtime.SetMutexProfileFraction(0)
|
||||||
|
}
|
||||||
|
if c.ProfileType.Block {
|
||||||
|
runtime.SetBlockProfileRate(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
177
internal/profiling/profiling_test.go
Normal file
177
internal/profiling/profiling_test.go
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
package profiling
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/grafana/pyroscope-go"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/syncx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStart(t *testing.T) {
|
||||||
|
t.Run("profiling", func(t *testing.T) {
|
||||||
|
var c Config
|
||||||
|
assert.NoError(t, conf.FillDefault(&c))
|
||||||
|
c.Name = "test"
|
||||||
|
p := newProfiler(c)
|
||||||
|
assert.NotNil(t, p)
|
||||||
|
assert.NoError(t, p.Start())
|
||||||
|
assert.NoError(t, p.Stop())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid config", func(t *testing.T) {
|
||||||
|
mp := &mockProfiler{}
|
||||||
|
newProfiler = func(c Config) profiler {
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
Start(Config{})
|
||||||
|
|
||||||
|
Start(Config{
|
||||||
|
ServerAddr: "localhost:4040",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test start profiler", func(t *testing.T) {
|
||||||
|
mp := &mockProfiler{}
|
||||||
|
newProfiler = func(c Config) profiler {
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
c := Config{
|
||||||
|
Name: "test",
|
||||||
|
ServerAddr: "localhost:4040",
|
||||||
|
CheckInterval: time.Millisecond,
|
||||||
|
ProfilingDuration: time.Millisecond * 10,
|
||||||
|
CpuThreshold: 0,
|
||||||
|
}
|
||||||
|
var done = make(chan struct{})
|
||||||
|
go startPyroscope(c, done)
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
close(done)
|
||||||
|
|
||||||
|
assert.True(t, mp.started.True())
|
||||||
|
assert.True(t, mp.stopped.True())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("test start profiler with cpu overloaded", func(t *testing.T) {
|
||||||
|
mp := &mockProfiler{}
|
||||||
|
newProfiler = func(c Config) profiler {
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
c := Config{
|
||||||
|
Name: "test",
|
||||||
|
ServerAddr: "localhost:4040",
|
||||||
|
CheckInterval: time.Millisecond,
|
||||||
|
ProfilingDuration: time.Millisecond * 10,
|
||||||
|
CpuThreshold: 900,
|
||||||
|
}
|
||||||
|
var done = make(chan struct{})
|
||||||
|
go startPyroscope(c, done)
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
close(done)
|
||||||
|
|
||||||
|
assert.False(t, mp.started.True())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("start/stop err", func(t *testing.T) {
|
||||||
|
mp := &mockProfiler{
|
||||||
|
err: assert.AnError,
|
||||||
|
}
|
||||||
|
newProfiler = func(c Config) profiler {
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
c := Config{
|
||||||
|
Name: "test",
|
||||||
|
ServerAddr: "localhost:4040",
|
||||||
|
CheckInterval: time.Millisecond,
|
||||||
|
ProfilingDuration: time.Millisecond * 10,
|
||||||
|
CpuThreshold: 0,
|
||||||
|
}
|
||||||
|
var done = make(chan struct{})
|
||||||
|
go startPyroscope(c, done)
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
close(done)
|
||||||
|
|
||||||
|
assert.False(t, mp.started.True())
|
||||||
|
assert.False(t, mp.stopped.True())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenPyroscopeConf(t *testing.T) {
|
||||||
|
c := Config{
|
||||||
|
Name: "",
|
||||||
|
ServerAddr: "localhost:4040",
|
||||||
|
AuthUser: "user",
|
||||||
|
AuthPassword: "password",
|
||||||
|
ProfileType: ProfileType{
|
||||||
|
Logger: true,
|
||||||
|
CPU: true,
|
||||||
|
Goroutines: true,
|
||||||
|
Memory: true,
|
||||||
|
Mutex: true,
|
||||||
|
Block: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pyroscopeConf := genPyroscopeConf(c)
|
||||||
|
assert.Equal(t, c.ServerAddr, pyroscopeConf.ServerAddress)
|
||||||
|
assert.Equal(t, c.AuthUser, pyroscopeConf.BasicAuthUser)
|
||||||
|
assert.Equal(t, c.AuthPassword, pyroscopeConf.BasicAuthPassword)
|
||||||
|
assert.Equal(t, c.Name, pyroscopeConf.ApplicationName)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileCPU)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileGoroutines)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocObjects)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocSpace)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseObjects)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseSpace)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexCount)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexDuration)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockCount)
|
||||||
|
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockDuration)
|
||||||
|
|
||||||
|
setFraction(c)
|
||||||
|
resetFraction(c)
|
||||||
|
|
||||||
|
newPyroscopeProfiler(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPyroscopeProfiler(t *testing.T) {
|
||||||
|
p := newPyroscopeProfiler(Config{})
|
||||||
|
|
||||||
|
assert.Error(t, p.Start())
|
||||||
|
assert.NoError(t, p.Stop())
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockProfiler struct {
|
||||||
|
mutex sync.Mutex
|
||||||
|
started syncx.AtomicBool
|
||||||
|
stopped syncx.AtomicBool
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProfiler) Start() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.err == nil {
|
||||||
|
m.started.Set(true)
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProfiler) Stop() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.err == nil {
|
||||||
|
m.stopped.Set(true)
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
43
mcp/config.go
Normal file
43
mcp/config.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/rest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// McpConf defines the configuration for an MCP server.
|
||||||
|
// It embeds rest.RestConf for HTTP server settings
|
||||||
|
// and adds MCP-specific configuration options.
|
||||||
|
type McpConf struct {
|
||||||
|
rest.RestConf
|
||||||
|
Mcp struct {
|
||||||
|
// Name is the server name reported in initialize responses
|
||||||
|
Name string `json:",optional"`
|
||||||
|
|
||||||
|
// Version is the server version reported in initialize responses
|
||||||
|
Version string `json:",default=1.0.0"`
|
||||||
|
|
||||||
|
// ProtocolVersion is the MCP protocol version implemented
|
||||||
|
ProtocolVersion string `json:",default=2024-11-05"`
|
||||||
|
|
||||||
|
// BaseUrl is the base URL for the server, used in SSE endpoint messages
|
||||||
|
// If not set, defaults to http://localhost:{Port}
|
||||||
|
BaseUrl string `json:",optional"`
|
||||||
|
|
||||||
|
// SseEndpoint is the path for Server-Sent Events connections
|
||||||
|
SseEndpoint string `json:",default=/sse"`
|
||||||
|
|
||||||
|
// MessageEndpoint is the path for JSON-RPC requests
|
||||||
|
MessageEndpoint string `json:",default=/message"`
|
||||||
|
|
||||||
|
// Cors contains allowed CORS origins
|
||||||
|
Cors []string `json:",optional"`
|
||||||
|
|
||||||
|
// SseTimeout is the maximum time allowed for SSE connections
|
||||||
|
SseTimeout time.Duration `json:",default=24h"`
|
||||||
|
|
||||||
|
// MessageTimeout is the maximum time allowed for request execution
|
||||||
|
MessageTimeout time.Duration `json:",default=30s"`
|
||||||
|
}
|
||||||
|
}
|
||||||
63
mcp/config_test.go
Normal file
63
mcp/config_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMcpConfDefaults(t *testing.T) {
|
||||||
|
// Test default values are set correctly when unmarshalled from JSON
|
||||||
|
jsonConfig := `name: test-service
|
||||||
|
port: 8080
|
||||||
|
mcp:
|
||||||
|
name: test-mcp-server
|
||||||
|
version: 1.0.0
|
||||||
|
`
|
||||||
|
|
||||||
|
var c McpConf
|
||||||
|
err := conf.LoadFromYamlBytes([]byte(jsonConfig), &c)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Check default values
|
||||||
|
assert.Equal(t, "test-mcp-server", c.Mcp.Name)
|
||||||
|
assert.Equal(t, "1.0.0", c.Mcp.Version, "Default version should be 1.0.0")
|
||||||
|
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
||||||
|
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
||||||
|
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
||||||
|
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcpConfCustomValues(t *testing.T) {
|
||||||
|
// Test custom values can be set
|
||||||
|
jsonConfig := `{
|
||||||
|
"Name": "test-service",
|
||||||
|
"Port": 8080,
|
||||||
|
"Mcp": {
|
||||||
|
"Name": "test-mcp-server",
|
||||||
|
"Version": "2.0.0",
|
||||||
|
"ProtocolVersion": "2025-01-01",
|
||||||
|
"BaseUrl": "http://example.com",
|
||||||
|
"SseEndpoint": "/custom-sse",
|
||||||
|
"MessageEndpoint": "/custom-message",
|
||||||
|
"Cors": ["http://localhost:3000", "http://example.com"],
|
||||||
|
"MessageTimeout": "60s"
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
var c McpConf
|
||||||
|
err := conf.LoadFromJsonBytes([]byte(jsonConfig), &c)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Check custom values
|
||||||
|
assert.Equal(t, "test-mcp-server", c.Mcp.Name, "Name should be inherited from RestConf")
|
||||||
|
assert.Equal(t, "2.0.0", c.Mcp.Version, "Version should be customizable")
|
||||||
|
assert.Equal(t, "2025-01-01", c.Mcp.ProtocolVersion, "Protocol version should be customizable")
|
||||||
|
assert.Equal(t, "http://example.com", c.Mcp.BaseUrl, "BaseUrl should be customizable")
|
||||||
|
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
||||||
|
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
||||||
|
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
||||||
|
assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable")
|
||||||
|
}
|
||||||
443
mcp/integration_test.go
Normal file
443
mcp/integration_test.go
Normal file
@@ -0,0 +1,443 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// syncResponseRecorder is a thread-safe wrapper around httptest.ResponseRecorder
|
||||||
|
type syncResponseRecorder struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new synchronized response recorder
|
||||||
|
func newSyncResponseRecorder() *syncResponseRecorder {
|
||||||
|
return &syncResponseRecorder{
|
||||||
|
ResponseRecorder: httptest.NewRecorder(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override Write method to synchronize access
|
||||||
|
func (srr *syncResponseRecorder) Write(p []byte) (int, error) {
|
||||||
|
srr.mu.Lock()
|
||||||
|
defer srr.mu.Unlock()
|
||||||
|
return srr.ResponseRecorder.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override WriteHeader method to synchronize access
|
||||||
|
func (srr *syncResponseRecorder) WriteHeader(statusCode int) {
|
||||||
|
srr.mu.Lock()
|
||||||
|
defer srr.mu.Unlock()
|
||||||
|
srr.ResponseRecorder.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override Result method to synchronize access
|
||||||
|
func (srr *syncResponseRecorder) Result() *http.Response {
|
||||||
|
srr.mu.Lock()
|
||||||
|
defer srr.mu.Unlock()
|
||||||
|
return srr.ResponseRecorder.Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTPHandlerIntegration tests the HTTP handlers with a real server instance
|
||||||
|
func TestHTTPHandlerIntegration(t *testing.T) {
|
||||||
|
// Skip in short test mode
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a test configuration
|
||||||
|
conf := McpConf{}
|
||||||
|
conf.Mcp.Name = "test-integration"
|
||||||
|
conf.Mcp.Version = "1.0.0-test"
|
||||||
|
conf.Mcp.MessageTimeout = 1 * time.Second
|
||||||
|
|
||||||
|
// Create a mock server directly
|
||||||
|
server := &sseMcpServer{
|
||||||
|
conf: conf,
|
||||||
|
clients: make(map[string]*mcpClient),
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register a test tool
|
||||||
|
err := server.RegisterTool(Tool{
|
||||||
|
Name: "echo",
|
||||||
|
Description: "Echo tool for testing",
|
||||||
|
InputSchema: InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Message to echo",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
if msg, ok := params["message"].(string); ok {
|
||||||
|
return fmt.Sprintf("Echo: %s", msg), nil
|
||||||
|
}
|
||||||
|
return "Echo: no message provided", nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a test HTTP request to the SSE endpoint
|
||||||
|
req := httptest.NewRequest("GET", "/sse", nil)
|
||||||
|
w := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Create a done channel to signal completion of test
|
||||||
|
done := make(chan bool)
|
||||||
|
|
||||||
|
// Start the SSE handler in a goroutine
|
||||||
|
go func() {
|
||||||
|
// lock.Lock()
|
||||||
|
server.handleSSE(w, req)
|
||||||
|
// lock.Unlock()
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Allow time for the handler to process
|
||||||
|
select {
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
// Expected - handler would normally block indefinitely
|
||||||
|
case <-done:
|
||||||
|
// This shouldn't happen immediately - the handler should block
|
||||||
|
t.Error("SSE handler returned unexpectedly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the initial headers
|
||||||
|
resp := w.Result()
|
||||||
|
assert.Equal(t, "chunked", resp.Header.Get("Transfer-Encoding"))
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
// The handler creates a client and sends the endpoint message
|
||||||
|
var sessionId string
|
||||||
|
|
||||||
|
// Give the handler time to set up the client
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Check that a client was created
|
||||||
|
server.clientsLock.Lock()
|
||||||
|
assert.Equal(t, 1, len(server.clients))
|
||||||
|
for id := range server.clients {
|
||||||
|
sessionId = id
|
||||||
|
}
|
||||||
|
server.clientsLock.Unlock()
|
||||||
|
|
||||||
|
require.NotEmpty(t, sessionId, "Expected a session ID to be created")
|
||||||
|
|
||||||
|
// Now that we have a session ID, we can test the message endpoint
|
||||||
|
messageBody, _ := json.Marshal(Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: methodInitialize,
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a message request
|
||||||
|
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, sessionId)
|
||||||
|
msgReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(messageBody))
|
||||||
|
msgW := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Process the message
|
||||||
|
server.handleRequest(msgW, msgReq)
|
||||||
|
|
||||||
|
// Check the response
|
||||||
|
msgResp := msgW.Result()
|
||||||
|
assert.Equal(t, http.StatusAccepted, msgResp.StatusCode)
|
||||||
|
msgResp.Body.Close() // Ensure response body is closed
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandlerResponseFlow tests the flow of a full request/response cycle
|
||||||
|
func TestHandlerResponseFlow(t *testing.T) {
|
||||||
|
// Create a mock server for testing
|
||||||
|
server := &sseMcpServer{
|
||||||
|
conf: McpConf{},
|
||||||
|
clients: map[string]*mcpClient{
|
||||||
|
"test-session": {
|
||||||
|
id: "test-session",
|
||||||
|
channel: make(chan string, 10),
|
||||||
|
initialized: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register test resources
|
||||||
|
server.RegisterTool(Tool{
|
||||||
|
Name: "test.tool",
|
||||||
|
Description: "Test tool",
|
||||||
|
InputSchema: InputSchema{Type: "object"},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
return "tool result", nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
server.RegisterPrompt(Prompt{
|
||||||
|
Name: "test.prompt",
|
||||||
|
Description: "Test prompt",
|
||||||
|
})
|
||||||
|
|
||||||
|
server.RegisterResource(Resource{
|
||||||
|
Name: "test.resource",
|
||||||
|
URI: "http://example.com",
|
||||||
|
Description: "Test resource",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a request with session ID parameter
|
||||||
|
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, "test-session")
|
||||||
|
|
||||||
|
// Test tools/list request
|
||||||
|
toolsListBody, _ := json.Marshal(Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: methodToolsList,
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
toolsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(toolsListBody))
|
||||||
|
toolsW := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
server.handleRequest(toolsW, toolsReq)
|
||||||
|
|
||||||
|
// Check the response code
|
||||||
|
toolsResp := toolsW.Result()
|
||||||
|
assert.Equal(t, http.StatusAccepted, toolsResp.StatusCode)
|
||||||
|
toolsResp.Body.Close()
|
||||||
|
|
||||||
|
// Check the channel message
|
||||||
|
client := server.clients["test-session"]
|
||||||
|
select {
|
||||||
|
case message := <-client.channel:
|
||||||
|
assert.Contains(t, message, `"tools":[{"name":"test.tool"`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for tools/list response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test prompts/list request
|
||||||
|
promptsListBody, _ := json.Marshal(Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 2,
|
||||||
|
Method: methodPromptsList,
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
promptsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(promptsListBody))
|
||||||
|
promptsW := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
server.handleRequest(promptsW, promptsReq)
|
||||||
|
|
||||||
|
// Check the response code
|
||||||
|
promptsResp := promptsW.Result()
|
||||||
|
assert.Equal(t, http.StatusAccepted, promptsResp.StatusCode)
|
||||||
|
promptsResp.Body.Close()
|
||||||
|
|
||||||
|
// Check the channel message
|
||||||
|
select {
|
||||||
|
case message := <-client.channel:
|
||||||
|
assert.Contains(t, message, `"prompts":[{"name":"test.prompt"`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for prompts/list response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test resources/list request
|
||||||
|
resourcesListBody, _ := json.Marshal(Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 3,
|
||||||
|
Method: methodResourcesList,
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
})
|
||||||
|
|
||||||
|
resourcesReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(resourcesListBody))
|
||||||
|
resourcesW := newSyncResponseRecorder()
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
server.handleRequest(resourcesW, resourcesReq)
|
||||||
|
|
||||||
|
// Check the response code
|
||||||
|
resourcesResp := resourcesW.Result()
|
||||||
|
assert.Equal(t, http.StatusAccepted, resourcesResp.StatusCode)
|
||||||
|
resourcesResp.Body.Close()
|
||||||
|
|
||||||
|
// Check the channel message
|
||||||
|
select {
|
||||||
|
case message := <-client.channel:
|
||||||
|
assert.Contains(t, message, `"name":"test.resource"`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for resources/list response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProcessListMethods tests the list processing methods with pagination
|
||||||
|
func TestProcessListMethods(t *testing.T) {
|
||||||
|
server := &sseMcpServer{
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some test data
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
tool := Tool{
|
||||||
|
Name: fmt.Sprintf("tool%d", i),
|
||||||
|
Description: fmt.Sprintf("Tool %d", i),
|
||||||
|
InputSchema: InputSchema{Type: "object"},
|
||||||
|
}
|
||||||
|
server.tools[tool.Name] = tool
|
||||||
|
|
||||||
|
prompt := Prompt{
|
||||||
|
Name: fmt.Sprintf("prompt%d", i),
|
||||||
|
Description: fmt.Sprintf("Prompt %d", i),
|
||||||
|
}
|
||||||
|
server.prompts[prompt.Name] = prompt
|
||||||
|
|
||||||
|
resource := Resource{
|
||||||
|
Name: fmt.Sprintf("resource%d", i),
|
||||||
|
URI: fmt.Sprintf("http://example.com/%d", i),
|
||||||
|
Description: fmt.Sprintf("Resource %d", i),
|
||||||
|
}
|
||||||
|
server.resources[resource.Name] = resource
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := &mcpClient{
|
||||||
|
id: "test-client",
|
||||||
|
channel: make(chan string, 10),
|
||||||
|
initialized: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test processListTools
|
||||||
|
req := Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: methodToolsList,
|
||||||
|
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
server.processListTools(context.Background(), client, req)
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"tools":`)
|
||||||
|
assert.Contains(t, response, `"progressToken":"token1"`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for tools/list response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test processListPrompts
|
||||||
|
req.ID = 2
|
||||||
|
req.Method = methodPromptsList
|
||||||
|
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||||
|
server.processListPrompts(context.Background(), client, req)
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"prompts":`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for prompts/list response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test processListResources
|
||||||
|
req.ID = 3
|
||||||
|
req.Method = methodResourcesList
|
||||||
|
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||||
|
server.processListResources(context.Background(), client, req)
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"resources":`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for resources/list response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestErrorResponseHandling tests error handling in the server
|
||||||
|
func TestErrorResponseHandling(t *testing.T) {
|
||||||
|
server := &sseMcpServer{
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a test client
|
||||||
|
client := &mcpClient{
|
||||||
|
id: "test-client",
|
||||||
|
channel: make(chan string, 10),
|
||||||
|
initialized: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid method
|
||||||
|
req := Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: "invalid_method",
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock handleRequest by directly calling error handler
|
||||||
|
server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"error":{"code":-32601,"message":"Method not found"}`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for error response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid tool
|
||||||
|
toolReq := Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 2,
|
||||||
|
Method: methodToolsCall,
|
||||||
|
Params: json.RawMessage(`{"name":"non_existent_tool"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call process method directly
|
||||||
|
server.processToolCall(context.Background(), client, toolReq)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"error":{"code":-32602,"message":"Tool 'non_existent_tool' not found"}`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for error response")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid prompt
|
||||||
|
promptReq := Request{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 3,
|
||||||
|
Method: methodPromptsGet,
|
||||||
|
Params: json.RawMessage(`{"name":"non_existent_prompt"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call process method directly
|
||||||
|
server.processGetPrompt(context.Background(), client, promptReq)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
select {
|
||||||
|
case response := <-client.channel:
|
||||||
|
assert.Contains(t, response, `"error":{"code":-32602,"message":"Prompt 'non_existent_prompt' not found"}`)
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("Timed out waiting for error response")
|
||||||
|
}
|
||||||
|
}
|
||||||
23
mcp/parser.go
Normal file
23
mcp/parser.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/mapping"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseArguments parses the arguments and populates the request object
|
||||||
|
func ParseArguments(args any, req any) error {
|
||||||
|
switch arguments := args.(type) {
|
||||||
|
case map[string]string:
|
||||||
|
m := make(map[string]any, len(arguments))
|
||||||
|
for k, v := range arguments {
|
||||||
|
m[k] = v
|
||||||
|
}
|
||||||
|
return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues())
|
||||||
|
case map[string]any:
|
||||||
|
return mapping.UnmarshalJsonMap(arguments, req)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported argument type: %T", arguments)
|
||||||
|
}
|
||||||
|
}
|
||||||
139
mcp/parser_test.go
Normal file
139
mcp/parser_test.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestParseArguments_MapStringString tests parsing map[string]string arguments
|
||||||
|
func TestParseArguments_MapStringString(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test arguments
|
||||||
|
args := map[string]string{
|
||||||
|
"name": "test-name",
|
||||||
|
"message": "hello world",
|
||||||
|
"count": "42",
|
||||||
|
"enabled": "true",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a target object to populate
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
// Verify results
|
||||||
|
assert.NoError(t, err, "Should parse map[string]string without error")
|
||||||
|
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||||
|
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||||
|
assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int")
|
||||||
|
assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseArguments_MapStringAny tests parsing map[string]any arguments
|
||||||
|
func TestParseArguments_MapStringAny(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Count int `json:"count"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
Metadata map[string]string `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test arguments with mixed types
|
||||||
|
args := map[string]any{
|
||||||
|
"name": "test-name",
|
||||||
|
"message": "hello world",
|
||||||
|
"count": 42, // note: this is already an int
|
||||||
|
"enabled": true, // note: this is already a bool
|
||||||
|
"tags": []string{"tag1", "tag2"},
|
||||||
|
"metadata": map[string]string{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a target object to populate
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
// Verify results
|
||||||
|
assert.NoError(t, err, "Should parse map[string]any without error")
|
||||||
|
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||||
|
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||||
|
assert.Equal(t, 42, req.Count, "Count should be correctly parsed")
|
||||||
|
assert.True(t, req.Enabled, "Enabled should be correctly parsed")
|
||||||
|
assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed")
|
||||||
|
assert.Equal(t, map[string]string{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
}, req.Metadata, "Metadata should be correctly parsed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseArguments_UnsupportedType tests parsing with an unsupported type
|
||||||
|
func TestParseArguments_UnsupportedType(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use an unsupported argument type (slice)
|
||||||
|
args := []string{"not", "a", "map"}
|
||||||
|
|
||||||
|
// Create a target object to populate
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
// Parse the arguments
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
// Verify error is returned with correct message
|
||||||
|
assert.Error(t, err, "Should return error for unsupported type")
|
||||||
|
assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type")
|
||||||
|
assert.Contains(t, err.Error(), "[]string", "Error should include the actual type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseArguments_EmptyMap tests parsing with empty maps
|
||||||
|
func TestParseArguments_EmptyMap(t *testing.T) {
|
||||||
|
// Sample request struct to populate
|
||||||
|
type requestStruct struct {
|
||||||
|
Name string `json:"name,optional"`
|
||||||
|
Message string `json:"message,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test empty map[string]string
|
||||||
|
t.Run("EmptyMapStringString", func(t *testing.T) {
|
||||||
|
args := map[string]string{}
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
assert.NoError(t, err, "Should parse empty map[string]string without error")
|
||||||
|
assert.Empty(t, req.Name, "Name should be empty string")
|
||||||
|
assert.Empty(t, req.Message, "Message should be empty string")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test empty map[string]any
|
||||||
|
t.Run("EmptyMapStringAny", func(t *testing.T) {
|
||||||
|
args := map[string]any{}
|
||||||
|
var req requestStruct
|
||||||
|
|
||||||
|
err := ParseArguments(args, &req)
|
||||||
|
|
||||||
|
assert.NoError(t, err, "Should parse empty map[string]any without error")
|
||||||
|
assert.Empty(t, req.Name, "Name should be empty string")
|
||||||
|
assert.Empty(t, req.Message, "Message should be empty string")
|
||||||
|
})
|
||||||
|
}
|
||||||
870
mcp/readme.md
Normal file
870
mcp/readme.md
Normal file
@@ -0,0 +1,870 @@
|
|||||||
|
# Model Context Protocol (MCP) Implementation
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
This package implements the Model Context Protocol (MCP) server specification in Go, providing a framework for real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation follows the standardized protocol for building AI-assisted applications with bidirectional communication capabilities.
|
||||||
|
|
||||||
|
## Core Components
|
||||||
|
|
||||||
|
### Server-Sent Events (SSE) Communication
|
||||||
|
- **Real-time Communication**: Robust SSE-based communication system that maintains persistent connections with clients
|
||||||
|
- **Connection Management**: Client registration, message broadcasting, and client cleanup mechanisms
|
||||||
|
- **Event Handling**: Event types for tools, prompts, and resources changes
|
||||||
|
|
||||||
|
### JSON-RPC Implementation
|
||||||
|
- **Request Processing**: Complete JSON-RPC request processor for handling MCP protocol methods
|
||||||
|
- **Response Formatting**: Proper response formatting according to JSON-RPC specifications
|
||||||
|
- **Error Handling**: Comprehensive error handling with appropriate error codes
|
||||||
|
|
||||||
|
### Tool Management
|
||||||
|
- **Tool Registration**: System to register custom tools with handlers
|
||||||
|
- **Tool Execution**: Mechanism to execute tool functions with proper timeout handling
|
||||||
|
- **Result Handling**: Flexible result handling supporting various return types (string, JSON, images)
|
||||||
|
|
||||||
|
### Prompt System
|
||||||
|
- **Prompt Registration**: System for registering both static and dynamic prompts
|
||||||
|
- **Argument Validation**: Validation for required arguments and default values for optional ones
|
||||||
|
- **Message Generation**: Handlers that generate properly formatted conversation messages
|
||||||
|
|
||||||
|
### Resource Management
|
||||||
|
- **Resource Registration**: System for managing and accessing external resources
|
||||||
|
- **Content Delivery**: Handlers for delivering resource content to clients on demand
|
||||||
|
- **Resource Subscription**: Mechanisms for clients to subscribe to resource updates
|
||||||
|
|
||||||
|
### Protocol Features
|
||||||
|
- **Initialization Sequence**: Proper handshaking with capability negotiation
|
||||||
|
- **Notification Handling**: Support for both standard and client-specific notifications
|
||||||
|
- **Message Routing**: Intelligent routing of requests to appropriate handlers
|
||||||
|
|
||||||
|
## Technical Highlights
|
||||||
|
|
||||||
|
### Configuration System
|
||||||
|
- **Flexible Configuration**: Configuration system with sensible defaults and customization options
|
||||||
|
- **CORS Support**: Configurable CORS settings for cross-origin requests
|
||||||
|
- **Server Information**: Proper server identification and versioning
|
||||||
|
|
||||||
|
### Client Session Management
|
||||||
|
- **Session Tracking**: Client session tracking with unique identifiers
|
||||||
|
- **Connection Health**: Ping/pong mechanism to maintain connection health
|
||||||
|
- **Initialization State**: Client initialization state tracking
|
||||||
|
|
||||||
|
### Content Handling
|
||||||
|
- **Multi-format Content**: Support for text, code, and binary content
|
||||||
|
- **MIME Type Support**: Proper MIME type identification for various content types
|
||||||
|
- **Audience Annotations**: Content audience annotations for user/assistant targeting
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Setting Up an MCP Server
|
||||||
|
|
||||||
|
To create and start an MCP server:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load configuration from YAML file
|
||||||
|
var c mcp.McpConf
|
||||||
|
conf.MustLoad("config.yaml", &c)
|
||||||
|
|
||||||
|
// Optional: Disable stats logging
|
||||||
|
logx.DisableStat()
|
||||||
|
|
||||||
|
// Create MCP server
|
||||||
|
server := mcp.NewMcpServer(c)
|
||||||
|
|
||||||
|
// Register tools, prompts, and resources (examples below)
|
||||||
|
|
||||||
|
// Start the server and ensure it's stopped on exit
|
||||||
|
defer server.Stop()
|
||||||
|
server.Start()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Sample configuration file (config.yaml):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
name: mcp-server
|
||||||
|
host: localhost
|
||||||
|
port: 8080
|
||||||
|
mcp:
|
||||||
|
name: my-mcp-server
|
||||||
|
messageTimeout: 30s # Timeout for tool calls
|
||||||
|
cors:
|
||||||
|
- http://localhost:3000 # Optional CORS configuration
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering Tools
|
||||||
|
|
||||||
|
Tools allow AI models to execute custom code through the MCP protocol.
|
||||||
|
|
||||||
|
#### Basic Tool Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a simple echo tool
|
||||||
|
echoTool := mcp.Tool{
|
||||||
|
Name: "echo",
|
||||||
|
Description: "Echoes back the message provided by the user",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The message to echo back",
|
||||||
|
},
|
||||||
|
"prefix": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional prefix to add to the echoed message",
|
||||||
|
"default": "Echo: ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"message"},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Prefix string `json:"prefix,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := "Echo: "
|
||||||
|
if len(req.Prefix) > 0 {
|
||||||
|
prefix = req.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefix + req.Message, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(echoTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Tool with Different Response Types:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Tool returning JSON data
|
||||||
|
dataTool := mcp.Tool{
|
||||||
|
Name: "data.generate",
|
||||||
|
Description: "Generates sample data in various formats",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"format": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Format of data (json, text)",
|
||||||
|
"enum": []string{"json", "text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
Format string `json:"format"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Format == "json" {
|
||||||
|
// Return structured data
|
||||||
|
return map[string]any{
|
||||||
|
"items": []map[string]any{
|
||||||
|
{"id": 1, "name": "Item 1"},
|
||||||
|
{"id": 2, "name": "Item 2"},
|
||||||
|
},
|
||||||
|
"count": 2,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to text
|
||||||
|
return "Sample text data", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(dataTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Image Generation Tool Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Tool returning image content
|
||||||
|
imageTool := mcp.Tool{
|
||||||
|
Name: "image.generate",
|
||||||
|
Description: "Generates a simple image",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"type": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Type of image to generate",
|
||||||
|
"default": "placeholder",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
// Return image content directly
|
||||||
|
return mcp.ImageContent{
|
||||||
|
Data: "base64EncodedImageData...", // Base64 encoded image data
|
||||||
|
MimeType: "image/png",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(imageTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Using ToolResult for Custom Outputs:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Tool that returns a custom ToolResult type
|
||||||
|
customResultTool := mcp.Tool{
|
||||||
|
Name: "custom.result",
|
||||||
|
Description: "Returns a custom formatted result",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"resultType": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"enum": []string{"text", "image"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
ResultType string `json:"resultType"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ResultType == "image" {
|
||||||
|
return mcp.ToolResult{
|
||||||
|
Type: mcp.ContentTypeImage,
|
||||||
|
Content: map[string]any{
|
||||||
|
"data": "base64EncodedImageData...",
|
||||||
|
"mimeType": "image/jpeg",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to text
|
||||||
|
return mcp.ToolResult{
|
||||||
|
Type: mcp.ContentTypeText,
|
||||||
|
Content: "This is a text result from ToolResult",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterTool(customResultTool)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering Prompts
|
||||||
|
|
||||||
|
Prompts are reusable conversation templates for AI models.
|
||||||
|
|
||||||
|
#### Static Prompt Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a simple static prompt with placeholders
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "hello",
|
||||||
|
Description: "A simple hello prompt",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "name",
|
||||||
|
Description: "The name to greet",
|
||||||
|
Required: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Content: "Say hello to {{name}} and introduce yourself as an AI assistant.",
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Dynamic Prompt with Handler Function:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt with a dynamic handler function
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "dynamic-prompt",
|
||||||
|
Description: "A prompt that uses a handler to generate dynamic content",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "username",
|
||||||
|
Description: "User's name for personalized greeting",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "topic",
|
||||||
|
Description: "Topic of expertise",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a user message
|
||||||
|
userMessage := mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an assistant response with current time
|
||||||
|
currentTime := time.Now().Format(time.RFC1123)
|
||||||
|
assistantMessage := mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||||
|
req.Username, req.Topic, currentTime),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return both messages as a conversation
|
||||||
|
return []mcp.PromptMessage{userMessage, assistantMessage}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Multi-Message Prompt with Code Examples:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt that provides code examples in different programming languages
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "code-example",
|
||||||
|
Description: "Provides code examples in different programming languages",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "language",
|
||||||
|
Description: "Programming language for the example",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "complexity",
|
||||||
|
Description: "Complexity level (simple, medium, advanced)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Language string `json:"language"`
|
||||||
|
Complexity string `json:"complexity,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate language
|
||||||
|
supportedLanguages := map[string]bool{"go": true, "python": true, "javascript": true, "rust": true}
|
||||||
|
if !supportedLanguages[req.Language] {
|
||||||
|
return nil, fmt.Errorf("unsupported language: %s", req.Language)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate code example based on language and complexity
|
||||||
|
var codeExample string
|
||||||
|
|
||||||
|
switch req.Language {
|
||||||
|
case "go":
|
||||||
|
if req.Complexity == "simple" {
|
||||||
|
codeExample = `
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println("Hello, World!")
|
||||||
|
}`
|
||||||
|
} else {
|
||||||
|
codeExample = `
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
now := time.Now()
|
||||||
|
fmt.Printf("Hello, World! Current time is %s\n", now.Format(time.RFC3339))
|
||||||
|
}`
|
||||||
|
}
|
||||||
|
case "python":
|
||||||
|
// Python example code
|
||||||
|
if req.Complexity == "simple" {
|
||||||
|
codeExample = `
|
||||||
|
def greet(name):
|
||||||
|
return f"Hello, {name}!"
|
||||||
|
|
||||||
|
print(greet("World"))`
|
||||||
|
} else {
|
||||||
|
codeExample = `
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
def greet(name, include_time=False):
|
||||||
|
message = f"Hello, {name}!"
|
||||||
|
if include_time:
|
||||||
|
message += f" Current time is {datetime.datetime.now().isoformat()}"
|
||||||
|
return message
|
||||||
|
|
||||||
|
print(greet("World", include_time=True))`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create messages array according to MCP spec
|
||||||
|
messages := []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("You are a helpful coding assistant specialized in %s programming.", req.Language),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Show me a %s example of a Hello World program in %s.", req.Complexity, req.Language),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Here's a %s example in %s:\n\n```%s%s\n```\n\nHow can I help you implement this?",
|
||||||
|
req.Complexity, req.Language, req.Language, codeExample),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registering Resources
|
||||||
|
|
||||||
|
Resources provide access to external content such as files or generated data.
|
||||||
|
|
||||||
|
#### Basic Resource Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a static resource
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "example-document",
|
||||||
|
URI: "file:///example/document.txt",
|
||||||
|
Description: "An example document",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///example/document.txt",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Text: "This is an example document content.",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Dynamic Resource with Code Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a Go code resource with dynamic handler
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "go-example",
|
||||||
|
URI: "file:///project/src/main.go",
|
||||||
|
Description: "A simple Go example with multiple files",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
// Return ResourceContent with all required fields
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///project/src/main.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a companion file for the above example
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "go-greeting",
|
||||||
|
URI: "file:///project/src/greeting/greeting.go",
|
||||||
|
Description: "A greeting package for the Go example",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///project/src/greeting/greeting.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Binary Resource Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a binary resource (like an image)
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "example-image",
|
||||||
|
URI: "file:///example/image.png",
|
||||||
|
Description: "An example image",
|
||||||
|
MimeType: "image/png",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
// Read image from file or generate it
|
||||||
|
imageData := "base64EncodedImageData..." // Base64 encoded image data
|
||||||
|
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///example/image.png",
|
||||||
|
MimeType: "image/png",
|
||||||
|
Blob: imageData, // For binary data
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using Resources in Prompts
|
||||||
|
|
||||||
|
You can embed resources in prompt responses to create rich interactions with proper MCP-compliant structure:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt that embeds a resource
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "resource-example",
|
||||||
|
Description: "A prompt that embeds a resource",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "file_type",
|
||||||
|
Description: "Type of file to show (rust or go)",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
FileType string `json:"file_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resourceURI, mimeType, fileContent string
|
||||||
|
if req.FileType == "rust" {
|
||||||
|
resourceURI = "file:///project/src/main.rs"
|
||||||
|
mimeType = "text/x-rust"
|
||||||
|
fileContent = "fn main() {\n println!(\"Hello world!\");\n}"
|
||||||
|
} else {
|
||||||
|
resourceURI = "file:///project/src/main.go"
|
||||||
|
mimeType = "text/x-go"
|
||||||
|
fileContent = "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello, world!\")\n}"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create message with embedded resource using proper MCP format
|
||||||
|
return []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Can you explain this %s code?", req.FileType),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.EmbeddedResource{
|
||||||
|
Type: mcp.ContentTypeResource,
|
||||||
|
Resource: struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Blob string `json:"blob,omitempty"`
|
||||||
|
}{
|
||||||
|
URI: resourceURI,
|
||||||
|
MimeType: mimeType,
|
||||||
|
Text: fileContent,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Above is a simple Hello World example in %s. Let me explain how it works.", req.FileType),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple File Resources Example
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register a prompt that demonstrates embedding multiple resource files
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "go-code-example",
|
||||||
|
Description: "A prompt that correctly embeds multiple resource files",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "format",
|
||||||
|
Description: "How to format the code display",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Format string `json:"format,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the Go code for multiple files
|
||||||
|
var mainGoText string = "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}"
|
||||||
|
var greetingGoText string = "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}"
|
||||||
|
|
||||||
|
// Create message with properly formatted embedded resource per MCP spec
|
||||||
|
messages := []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: "Show me a simple Go example with proper imports.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: "Here's a simple Go example project:",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.EmbeddedResource{
|
||||||
|
Type: mcp.ContentTypeResource,
|
||||||
|
Resource: struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Blob string `json:"blob,omitempty"`
|
||||||
|
}{
|
||||||
|
URI: "file:///project/src/main.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: mainGoText,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add explanation and additional file if requested
|
||||||
|
if req.Format == "with_explanation" {
|
||||||
|
messages = append(messages, mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: "This example demonstrates a simple Go application with modular structure. The main.go file imports from a local 'greeting' package that provides the Hello function.",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Also show the greeting.go file with correct resource format
|
||||||
|
messages = append(messages, mcp.PromptMessage{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.EmbeddedResource{
|
||||||
|
Type: mcp.ContentTypeResource,
|
||||||
|
Resource: struct {
|
||||||
|
URI string `json:"uri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Blob string `json:"blob,omitempty"`
|
||||||
|
}{
|
||||||
|
URI: "file:///project/src/greeting/greeting.go",
|
||||||
|
MimeType: "text/x-go",
|
||||||
|
Text: greetingGoText,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complete Application Example
|
||||||
|
|
||||||
|
Here's a complete example demonstrating all the components:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load configuration
|
||||||
|
var c mcp.McpConf
|
||||||
|
if err := conf.Load("config.yaml", &c); err != nil {
|
||||||
|
log.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up logging
|
||||||
|
logx.DisableStat()
|
||||||
|
|
||||||
|
// Create MCP server
|
||||||
|
server := mcp.NewMcpServer(c)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
// Register a simple echo tool
|
||||||
|
echoTool := mcp.Tool{
|
||||||
|
Name: "echo",
|
||||||
|
Description: "Echoes back the message provided by the user",
|
||||||
|
InputSchema: mcp.InputSchema{
|
||||||
|
Properties: map[string]any{
|
||||||
|
"message": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The message to echo back",
|
||||||
|
},
|
||||||
|
"prefix": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional prefix to add to the echoed message",
|
||||||
|
"default": "Echo: ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"message"},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
var req struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Prefix string `json:"prefix,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := "Echo: "
|
||||||
|
if len(req.Prefix) > 0 {
|
||||||
|
prefix = req.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefix + req.Message, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
server.RegisterTool(echoTool)
|
||||||
|
|
||||||
|
// Register a static prompt
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "greeting",
|
||||||
|
Description: "A simple greeting prompt",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "name",
|
||||||
|
Description: "The name to greet",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Content: "Hello {{name}}! How can I assist you today?",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a dynamic prompt
|
||||||
|
server.RegisterPrompt(mcp.Prompt{
|
||||||
|
Name: "dynamic-prompt",
|
||||||
|
Description: "A prompt that uses a handler to generate dynamic content",
|
||||||
|
Arguments: []mcp.PromptArgument{
|
||||||
|
{
|
||||||
|
Name: "username",
|
||||||
|
Description: "User's name for personalized greeting",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "topic",
|
||||||
|
Description: "Topic of expertise",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||||
|
var req struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create messages with current time
|
||||||
|
currentTime := time.Now().Format(time.RFC1123)
|
||||||
|
return []mcp.PromptMessage{
|
||||||
|
{
|
||||||
|
Role: mcp.RoleUser,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: mcp.RoleAssistant,
|
||||||
|
Content: mcp.TextContent{
|
||||||
|
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||||
|
req.Username, req.Topic, currentTime),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register a resource
|
||||||
|
server.RegisterResource(mcp.Resource{
|
||||||
|
Name: "example-doc",
|
||||||
|
URI: "file:///example/doc.txt",
|
||||||
|
Description: "An example document",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||||
|
return mcp.ResourceContent{
|
||||||
|
URI: "file:///example/doc.txt",
|
||||||
|
MimeType: "text/plain",
|
||||||
|
Text: "This is the content of the example document.",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
fmt.Printf("Starting MCP server on %s:%d\n", c.Host, c.Port)
|
||||||
|
server.Start()
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
The MCP implementation provides comprehensive error handling:
|
||||||
|
|
||||||
|
- Tool execution errors are properly reported back to clients
|
||||||
|
- Missing or invalid parameters are detected and reported with appropriate error codes
|
||||||
|
- Resource and prompt lookup failures are handled gracefully
|
||||||
|
- Timeout handling for long-running tool executions using context
|
||||||
|
- Panic recovery to prevent server crashes
|
||||||
|
|
||||||
|
## Advanced Features
|
||||||
|
|
||||||
|
- **Annotations**: Add audience and priority metadata to content
|
||||||
|
- **Content Types**: Support for text, images, audio, and other content formats
|
||||||
|
- **Embedded Resources**: Include file resources directly in prompt responses
|
||||||
|
- **Context Awareness**: All handlers receive context.Context for timeout and cancellation support
|
||||||
|
- **Progress Tokens**: Support for tracking progress of long-running operations
|
||||||
|
- **Customizable Timeouts**: Configure execution timeouts for tools and operations
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
- Tool execution runs with configurable timeouts to prevent blocking
|
||||||
|
- Efficient client tracking and cleanup to prevent resource leaks
|
||||||
|
- Proper concurrency handling with mutex protection for shared resources
|
||||||
|
- Buffered message channels to prevent blocking on client message delivery
|
||||||
940
mcp/server.go
Normal file
940
mcp/server.go
Normal file
@@ -0,0 +1,940 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/rest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewMcpServer(c McpConf) McpServer {
|
||||||
|
var server *rest.Server
|
||||||
|
if len(c.Mcp.Cors) == 0 {
|
||||||
|
server = rest.MustNewServer(c.RestConf)
|
||||||
|
} else {
|
||||||
|
server = rest.MustNewServer(c.RestConf, rest.WithCors(c.Mcp.Cors...))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.Mcp.Name) == 0 {
|
||||||
|
c.Mcp.Name = c.Name
|
||||||
|
}
|
||||||
|
if len(c.Mcp.BaseUrl) == 0 {
|
||||||
|
c.Mcp.BaseUrl = fmt.Sprintf("http://localhost:%d", c.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &sseMcpServer{
|
||||||
|
conf: c,
|
||||||
|
server: server,
|
||||||
|
clients: make(map[string]*mcpClient),
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
prompts: make(map[string]Prompt),
|
||||||
|
resources: make(map[string]Resource),
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE endpoint for real-time updates
|
||||||
|
s.server.AddRoute(rest.Route{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: s.conf.Mcp.SseEndpoint,
|
||||||
|
Handler: s.handleSSE,
|
||||||
|
}, rest.WithSSE(), rest.WithTimeout(c.Mcp.SseTimeout))
|
||||||
|
|
||||||
|
// JSON-RPC message endpoint for regular requests
|
||||||
|
s.server.AddRoute(rest.Route{
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: s.conf.Mcp.MessageEndpoint,
|
||||||
|
Handler: s.handleRequest,
|
||||||
|
}, rest.WithTimeout(c.Mcp.MessageTimeout))
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterPrompt registers a new prompt with the server
|
||||||
|
func (s *sseMcpServer) RegisterPrompt(prompt Prompt) {
|
||||||
|
s.promptsLock.Lock()
|
||||||
|
s.prompts[prompt.Name] = prompt
|
||||||
|
s.promptsLock.Unlock()
|
||||||
|
// Notify clients about the new prompt
|
||||||
|
s.broadcast(eventPromptsListChanged, map[string][]Prompt{keyPrompts: {prompt}})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterResource registers a new resource with the server
|
||||||
|
func (s *sseMcpServer) RegisterResource(resource Resource) {
|
||||||
|
s.resourcesLock.Lock()
|
||||||
|
s.resources[resource.URI] = resource
|
||||||
|
s.resourcesLock.Unlock()
|
||||||
|
// Notify clients about the new resource
|
||||||
|
s.broadcast(eventResourcesListChanged, map[string][]Resource{keyResources: {resource}})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterTool registers a new tool with the server
|
||||||
|
func (s *sseMcpServer) RegisterTool(tool Tool) error {
|
||||||
|
if tool.Handler == nil {
|
||||||
|
return fmt.Errorf("tool '%s' has no handler function", tool.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.toolsLock.Lock()
|
||||||
|
s.tools[tool.Name] = tool
|
||||||
|
s.toolsLock.Unlock()
|
||||||
|
// Notify clients about the new tool
|
||||||
|
s.broadcast(eventToolsListChanged, map[string][]Tool{keyTools: {tool}})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start implements McpServer.
|
||||||
|
func (s *sseMcpServer) Start() {
|
||||||
|
s.server.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sseMcpServer) Stop() {
|
||||||
|
s.server.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// broadcast sends a message to all connected clients
|
||||||
|
// It uses Server-Sent Events (SSE) format for real-time communication
|
||||||
|
func (s *sseMcpServer) broadcast(event string, data any) {
|
||||||
|
jsonData, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
logx.Errorf("Failed to marshal broadcast data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lock only while reading the clients map
|
||||||
|
s.clientsLock.Lock()
|
||||||
|
clients := make([]*mcpClient, 0, len(s.clients))
|
||||||
|
for _, client := range s.clients {
|
||||||
|
clients = append(clients, client)
|
||||||
|
}
|
||||||
|
s.clientsLock.Unlock()
|
||||||
|
|
||||||
|
clientCount := len(clients)
|
||||||
|
if clientCount == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.Infof("Broadcasting event '%s' to %d clients", event, clientCount)
|
||||||
|
|
||||||
|
// Use CRLF line endings as per SSE specification
|
||||||
|
message := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(jsonData))
|
||||||
|
|
||||||
|
// Send messages without holding the lock
|
||||||
|
for _, client := range clients {
|
||||||
|
select {
|
||||||
|
case client.channel <- message:
|
||||||
|
// Message sent successfully
|
||||||
|
default:
|
||||||
|
// Channel buffer is full, log warning and continue
|
||||||
|
logx.Errorf("Client channel buffer full, dropping message for client %s", client.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupClient removes a client from the active clients map
|
||||||
|
func (s *sseMcpServer) cleanupClient(sessionId string) {
|
||||||
|
s.clientsLock.Lock()
|
||||||
|
defer s.clientsLock.Unlock()
|
||||||
|
|
||||||
|
if client, exists := s.clients[sessionId]; exists {
|
||||||
|
// Close the channel to signal any goroutines waiting on it
|
||||||
|
close(client.channel)
|
||||||
|
// Remove from active clients
|
||||||
|
delete(s.clients, sessionId)
|
||||||
|
logx.Infof("Cleaned up client %s (remaining clients: %d)", sessionId, len(s.clients))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleRequest handles MCP JSON-RPC requests
|
||||||
|
func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Extract sessionId from query parameters
|
||||||
|
sessionId := r.URL.Query().Get(sessionIdKey)
|
||||||
|
if len(sessionId) == 0 {
|
||||||
|
http.Error(w, fmt.Sprintf("Missing %s", sessionIdKey), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the client with this sessionId exists
|
||||||
|
s.clientsLock.Lock()
|
||||||
|
client, exists := s.clients[sessionId]
|
||||||
|
s.clientsLock.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
http.Error(w, fmt.Sprintf("Invalid or expired %s", sessionIdKey), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req Request
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For notification methods (no ID), we don't send a response
|
||||||
|
isNotification, err := req.isNotification()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid request.ID", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
|
||||||
|
// Special handling for initialization sequence
|
||||||
|
// Always allow initialize and notifications/initialized regardless of client state
|
||||||
|
if req.Method == methodInitialize {
|
||||||
|
logx.Infof("Processing initialize request with ID: %v", req.ID)
|
||||||
|
s.processInitialize(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID)
|
||||||
|
return
|
||||||
|
} else if req.Method == methodNotificationsInitialized {
|
||||||
|
// Handle initialized notification
|
||||||
|
logx.Info("Received notifications/initialized notification")
|
||||||
|
if !isNotification {
|
||||||
|
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||||
|
"Method should be used as a notification", errCodeInvalidRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.processNotificationInitialized(client)
|
||||||
|
return
|
||||||
|
} else if !client.initialized && req.Method != methodNotificationsCancelled {
|
||||||
|
// Block most requests until client is initialized (except for cancellations)
|
||||||
|
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||||
|
"Client not fully initialized, waiting for notifications/initialized",
|
||||||
|
errCodeClientNotInitialized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process normal requests only after initialization
|
||||||
|
switch req.Method {
|
||||||
|
case methodToolsCall:
|
||||||
|
logx.Infof("Received tools call request with ID: %v", req.ID)
|
||||||
|
s.processToolCall(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent tools call response for ID: %v", req.ID)
|
||||||
|
case methodToolsList:
|
||||||
|
logx.Infof("Processing tools/list request with ID: %v", req.ID)
|
||||||
|
s.processListTools(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent tools/list response for ID: %v", req.ID)
|
||||||
|
case methodPromptsList:
|
||||||
|
logx.Infof("Processing prompts/list request with ID: %v", req.ID)
|
||||||
|
s.processListPrompts(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent prompts/list response for ID: %v", req.ID)
|
||||||
|
case methodPromptsGet:
|
||||||
|
logx.Infof("Processing prompts/get request with ID: %v", req.ID)
|
||||||
|
s.processGetPrompt(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent prompts/get response for ID: %v", req.ID)
|
||||||
|
case methodResourcesList:
|
||||||
|
logx.Infof("Processing resources/list request with ID: %v", req.ID)
|
||||||
|
s.processListResources(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent resources/list response for ID: %v", req.ID)
|
||||||
|
case methodResourcesRead:
|
||||||
|
logx.Infof("Processing resources/read request with ID: %v", req.ID)
|
||||||
|
s.processResourcesRead(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent resources/read response for ID: %v", req.ID)
|
||||||
|
case methodResourcesSubscribe:
|
||||||
|
logx.Infof("Processing resources/subscribe request with ID: %v", req.ID)
|
||||||
|
s.processResourceSubscribe(r.Context(), client, req)
|
||||||
|
logx.Infof("Sent resources/subscribe response for ID: %v", req.ID)
|
||||||
|
case methodPing:
|
||||||
|
logx.Infof("Processing ping request with ID: %v", req.ID)
|
||||||
|
s.processPing(r.Context(), client, req)
|
||||||
|
case methodNotificationsCancelled:
|
||||||
|
logx.Infof("Received notifications/cancelled notification: %v", req.ID)
|
||||||
|
s.processNotificationCancelled(r.Context(), client, req)
|
||||||
|
default:
|
||||||
|
logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID)
|
||||||
|
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSSE handles Server-Sent Events connections
|
||||||
|
func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Generate a unique session ID for this client
|
||||||
|
sessionId := uuid.New().String()
|
||||||
|
|
||||||
|
// Create new client with buffered channel to prevent blocking
|
||||||
|
client := &mcpClient{
|
||||||
|
id: sessionId,
|
||||||
|
channel: make(chan string, eventChanSize),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add client to active clients map
|
||||||
|
s.clientsLock.Lock()
|
||||||
|
s.clients[sessionId] = client
|
||||||
|
activeClients := len(s.clients)
|
||||||
|
s.clientsLock.Unlock()
|
||||||
|
|
||||||
|
logx.Infof("New SSE connection established for client %s (active clients: %d)",
|
||||||
|
sessionId, activeClients)
|
||||||
|
|
||||||
|
// Set proper SSE headers
|
||||||
|
w.Header().Set("Transfer-Encoding", "chunked")
|
||||||
|
|
||||||
|
// Enable streaming
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
logx.Error("Streaming not supported by the underlying http.ResponseWriter")
|
||||||
|
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the message endpoint URL to the client
|
||||||
|
endpoint := fmt.Sprintf("%s%s?%s=%s",
|
||||||
|
s.conf.Mcp.BaseUrl, s.conf.Mcp.MessageEndpoint, sessionIdKey, sessionId)
|
||||||
|
|
||||||
|
// Format and send the endpoint message
|
||||||
|
endpointMsg := formatSSEMessage(eventEndpoint, []byte(endpoint))
|
||||||
|
if _, err := fmt.Fprint(w, endpointMsg); err != nil {
|
||||||
|
logx.Errorf("Failed to send endpoint message to client %s: %v", sessionId, err)
|
||||||
|
s.cleanupClient(sessionId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
// Set up keep-alive ping and client cleanup
|
||||||
|
ticker := time.NewTicker(pingInterval.Load())
|
||||||
|
defer func() {
|
||||||
|
ticker.Stop()
|
||||||
|
s.cleanupClient(sessionId)
|
||||||
|
logx.Infof("SSE connection closed for client %s", sessionId)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Message processing loop
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case message, ok := <-client.channel:
|
||||||
|
if !ok {
|
||||||
|
// Channel was closed, end connection
|
||||||
|
logx.Infof("Client channel was closed for %s", sessionId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write message to the response
|
||||||
|
if _, err := fmt.Fprint(w, message); err != nil {
|
||||||
|
logx.Infof("Failed to write message to client %s: %v", sessionId, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
case <-ticker.C:
|
||||||
|
// Send keep-alive ping to maintain connection
|
||||||
|
ping := fmt.Sprintf(`{"type":"ping","timestamp":"%s"}`, time.Now().String())
|
||||||
|
pingMsg := formatSSEMessage("ping", []byte(ping))
|
||||||
|
if _, err := fmt.Fprint(w, pingMsg); err != nil {
|
||||||
|
logx.Errorf("Failed to send ping to client %s, closing connection: %v", sessionId, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
case <-r.Context().Done():
|
||||||
|
// Client disconnected or request was canceled or timed out
|
||||||
|
logx.Infof("Client %s disconnected: context done", sessionId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processInitialize processes the initialize request
|
||||||
|
func (s *sseMcpServer) processInitialize(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
// Create a proper JSON-RPC response that preserves the client's request ID
|
||||||
|
result := initializationResponse{
|
||||||
|
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
|
||||||
|
Capabilities: capabilities{
|
||||||
|
Prompts: struct {
|
||||||
|
ListChanged bool `json:"listChanged"`
|
||||||
|
}{
|
||||||
|
ListChanged: true,
|
||||||
|
},
|
||||||
|
Resources: struct {
|
||||||
|
Subscribe bool `json:"subscribe"`
|
||||||
|
ListChanged bool `json:"listChanged"`
|
||||||
|
}{
|
||||||
|
Subscribe: true,
|
||||||
|
ListChanged: true,
|
||||||
|
},
|
||||||
|
Tools: struct {
|
||||||
|
ListChanged bool `json:"listChanged"`
|
||||||
|
}{
|
||||||
|
ListChanged: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ServerInfo: serverInfo{
|
||||||
|
Name: s.conf.Mcp.Name,
|
||||||
|
Version: s.conf.Mcp.Version,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark client as initialized
|
||||||
|
client.initialized = true
|
||||||
|
|
||||||
|
// Send response with client's original request ID
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processListTools processes the tools/list request
|
||||||
|
func (s *sseMcpServer) processListTools(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
// Extract pagination params if any
|
||||||
|
var nextCursor string
|
||||||
|
var progressToken any
|
||||||
|
|
||||||
|
// Extract meta data including progress token
|
||||||
|
if req.Params != nil {
|
||||||
|
var metaParams struct {
|
||||||
|
Cursor string `json:"cursor"`
|
||||||
|
Meta struct {
|
||||||
|
ProgressToken any `json:"progressToken"`
|
||||||
|
} `json:"_meta"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
|
||||||
|
if len(metaParams.Cursor) > 0 {
|
||||||
|
nextCursor = metaParams.Cursor
|
||||||
|
}
|
||||||
|
progressToken = metaParams.Meta.ProgressToken
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolsList []Tool
|
||||||
|
s.toolsLock.Lock()
|
||||||
|
for _, tool := range s.tools {
|
||||||
|
if len(tool.InputSchema.Type) == 0 {
|
||||||
|
tool.InputSchema.Type = ContentTypeObject
|
||||||
|
}
|
||||||
|
toolsList = append(toolsList, tool)
|
||||||
|
}
|
||||||
|
s.toolsLock.Unlock()
|
||||||
|
|
||||||
|
result := ListToolsResult{
|
||||||
|
PaginatedResult: PaginatedResult{
|
||||||
|
Result: Result{},
|
||||||
|
NextCursor: Cursor(nextCursor),
|
||||||
|
},
|
||||||
|
Tools: toolsList,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add meta information if progress token was provided
|
||||||
|
if progressToken != nil {
|
||||||
|
result.Result.Meta = map[string]any{
|
||||||
|
progressTokenKey: progressToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processListPrompts processes the prompts/list request
|
||||||
|
func (s *sseMcpServer) processListPrompts(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
// Extract pagination params if any
|
||||||
|
var nextCursor string
|
||||||
|
if req.Params != nil {
|
||||||
|
var cursorParams struct {
|
||||||
|
Cursor string `json:"cursor"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(req.Params, &cursorParams); err == nil && cursorParams.Cursor != "" {
|
||||||
|
// If we have a valid cursor, we could use it for pagination
|
||||||
|
// For now, we're not actually implementing pagination, so this is just
|
||||||
|
// to show how it would be extracted from the request
|
||||||
|
_ = cursorParams.Cursor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare prompt list
|
||||||
|
var promptsList []Prompt
|
||||||
|
s.promptsLock.Lock()
|
||||||
|
for _, prompt := range s.prompts {
|
||||||
|
promptsList = append(promptsList, prompt)
|
||||||
|
}
|
||||||
|
s.promptsLock.Unlock()
|
||||||
|
|
||||||
|
// In a real implementation, you'd handle pagination here
|
||||||
|
// For now, we'll return all prompts at once
|
||||||
|
result := struct {
|
||||||
|
Prompts []Prompt `json:"prompts"`
|
||||||
|
NextCursor string `json:"nextCursor,omitempty"`
|
||||||
|
Meta *struct{} `json:"_meta,omitempty"`
|
||||||
|
}{
|
||||||
|
Prompts: promptsList,
|
||||||
|
NextCursor: nextCursor,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processListResources processes the resources/list request
|
||||||
|
func (s *sseMcpServer) processListResources(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
// Extract pagination params if any
|
||||||
|
var nextCursor string
|
||||||
|
var progressToken any
|
||||||
|
|
||||||
|
// Extract meta information including progress token if available
|
||||||
|
if req.Params != nil {
|
||||||
|
var metaParams PaginatedParams
|
||||||
|
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
|
||||||
|
if len(metaParams.Cursor) > 0 {
|
||||||
|
nextCursor = metaParams.Cursor
|
||||||
|
}
|
||||||
|
progressToken = metaParams.Meta.ProgressToken
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var resourcesList []Resource
|
||||||
|
s.resourcesLock.Lock()
|
||||||
|
for _, resource := range s.resources {
|
||||||
|
// Create a copy without the handler function which shouldn't be sent to clients
|
||||||
|
resourceCopy := Resource{
|
||||||
|
URI: resource.URI,
|
||||||
|
Name: resource.Name,
|
||||||
|
Description: resource.Description,
|
||||||
|
MimeType: resource.MimeType,
|
||||||
|
}
|
||||||
|
resourcesList = append(resourcesList, resourceCopy)
|
||||||
|
}
|
||||||
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
|
// Create proper ResourcesListResult according to MCP specification
|
||||||
|
result := ResourcesListResult{
|
||||||
|
PaginatedResult: PaginatedResult{
|
||||||
|
Result: Result{},
|
||||||
|
NextCursor: Cursor(nextCursor),
|
||||||
|
},
|
||||||
|
Resources: resourcesList,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add meta information if progress token was provided
|
||||||
|
if progressToken != nil {
|
||||||
|
result.Result.Meta = map[string]any{
|
||||||
|
progressTokenKey: progressToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processGetPrompt processes the prompts/get request
|
||||||
|
func (s *sseMcpServer) processGetPrompt(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
type GetPromptParams struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments map[string]string `json:"arguments,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var params GetPromptParams
|
||||||
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if prompt exists
|
||||||
|
s.promptsLock.Lock()
|
||||||
|
prompt, exists := s.prompts[params.Name]
|
||||||
|
s.promptsLock.Unlock()
|
||||||
|
if !exists {
|
||||||
|
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.Infof("Processing prompt request: %s with %d arguments", prompt.Name, len(params.Arguments))
|
||||||
|
|
||||||
|
// Validate required arguments
|
||||||
|
missingArgs := validatePromptArguments(prompt, params.Arguments)
|
||||||
|
if len(missingArgs) > 0 {
|
||||||
|
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure arguments are initialized to an empty map if nil
|
||||||
|
if params.Arguments == nil {
|
||||||
|
params.Arguments = make(map[string]string)
|
||||||
|
}
|
||||||
|
args := params.Arguments
|
||||||
|
|
||||||
|
// Generate messages using handler or static content
|
||||||
|
var messages []PromptMessage
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if prompt.Handler != nil {
|
||||||
|
// Use dynamic handler to generate messages
|
||||||
|
messages, err = prompt.Handler(ctx, args)
|
||||||
|
if err != nil {
|
||||||
|
logx.Errorf("Error from prompt handler: %v", err)
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID,
|
||||||
|
fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No handler, generate messages from static content
|
||||||
|
var messageText string
|
||||||
|
if len(prompt.Content) > 0 {
|
||||||
|
messageText = prompt.Content
|
||||||
|
|
||||||
|
// Apply argument substitutions to static content
|
||||||
|
for key, value := range args {
|
||||||
|
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||||
|
messageText = strings.Replace(messageText, placeholder, value, -1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a single user message with the content
|
||||||
|
messages = []PromptMessage{
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: TextContent{
|
||||||
|
Text: messageText,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct the response according to MCP spec
|
||||||
|
result := struct {
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Messages []PromptMessage `json:"messages"`
|
||||||
|
}{
|
||||||
|
Description: prompt.Description,
|
||||||
|
Messages: toTypedPromptMessages(messages),
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processToolCall processes the tools/call request
|
||||||
|
func (s *sseMcpServer) processToolCall(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
var toolCallParams struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments map[string]any `json:"arguments,omitempty"`
|
||||||
|
Meta struct {
|
||||||
|
ProgressToken any `json:"progressToken"`
|
||||||
|
} `json:"_meta,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle different types of req.Params
|
||||||
|
// If it's a RawMessage (JSON), unmarshal it
|
||||||
|
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
|
||||||
|
logx.Errorf("Failed to unmarshal tool call params: %v", err)
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract progress token if available
|
||||||
|
progressToken := toolCallParams.Meta.ProgressToken
|
||||||
|
|
||||||
|
// Find the requested tool
|
||||||
|
s.toolsLock.Lock()
|
||||||
|
tool, exists := s.tools[toolCallParams.Name]
|
||||||
|
s.toolsLock.Unlock()
|
||||||
|
if !exists {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||||
|
toolCallParams.Name), errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log parameters before execution
|
||||||
|
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
|
||||||
|
|
||||||
|
// Execute the tool handler with timeout handling
|
||||||
|
var result any
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Create a channel to receive the result
|
||||||
|
// make sure to have 1 size buffer to avoid channel leak if timeout
|
||||||
|
resultCh := make(chan struct {
|
||||||
|
result any
|
||||||
|
err error
|
||||||
|
}, 1)
|
||||||
|
|
||||||
|
// Execute the tool handler in a goroutine
|
||||||
|
go func() {
|
||||||
|
toolResult, toolErr := tool.Handler(ctx, toolCallParams.Arguments)
|
||||||
|
resultCh <- struct {
|
||||||
|
result any
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
result: toolResult,
|
||||||
|
err: toolErr,
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for either the result or a timeout
|
||||||
|
select {
|
||||||
|
case res := <-resultCh:
|
||||||
|
result = res.result
|
||||||
|
err = res.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Handle request timeout
|
||||||
|
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.MessageTimeout, toolCallParams.Name)
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the base result structure with metadata
|
||||||
|
callToolResult := CallToolResult{
|
||||||
|
Result: Result{},
|
||||||
|
Content: []any{},
|
||||||
|
IsError: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add meta information if progress token was provided
|
||||||
|
if progressToken != nil {
|
||||||
|
callToolResult.Result.Meta = map[string]any{
|
||||||
|
progressTokenKey: progressToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if there was an error during tool execution
|
||||||
|
if err != nil {
|
||||||
|
// According to the spec, for tool-level errors (as opposed to protocol-level errors),
|
||||||
|
// we should report them inside the result with isError=true
|
||||||
|
logx.Errorf("Tool execution reported error: %v", err)
|
||||||
|
|
||||||
|
callToolResult.Content = []any{
|
||||||
|
TextContent{
|
||||||
|
Text: fmt.Sprintf("Error: %v", err),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
callToolResult.IsError = true
|
||||||
|
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format the response according to the CallToolResult schema
|
||||||
|
switch v := result.(type) {
|
||||||
|
case string:
|
||||||
|
// Simple string becomes text content
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Text: v,
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case map[string]any:
|
||||||
|
// JSON-like object becomes formatted JSON text
|
||||||
|
jsonStr, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
jsonStr = []byte(err.Error())
|
||||||
|
}
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Text: string(jsonStr),
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case TextContent:
|
||||||
|
callToolResult.Content = append(callToolResult.Content, v)
|
||||||
|
case ImageContent:
|
||||||
|
callToolResult.Content = append(callToolResult.Content, v)
|
||||||
|
case []any:
|
||||||
|
callToolResult.Content = v
|
||||||
|
case ToolResult:
|
||||||
|
// Handle legacy ToolResult type
|
||||||
|
switch v.Type {
|
||||||
|
case ContentTypeText:
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Text: fmt.Sprintf("%v", v.Content),
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case ContentTypeImage:
|
||||||
|
if imgData, ok := v.Content.(map[string]any); ok {
|
||||||
|
callToolResult.Content = append(callToolResult.Content, ImageContent{
|
||||||
|
Data: fmt.Sprintf("%v", imgData["data"]),
|
||||||
|
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Text: fmt.Sprintf("%v", v.Content),
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// For any other type, convert to string
|
||||||
|
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||||
|
Text: fmt.Sprintf("%v", v),
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
callToolResult.Content = toTypedContents(callToolResult.Content)
|
||||||
|
logx.Infof("Tool call result: %#v", callToolResult)
|
||||||
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processResourcesRead processes the resources/read request
|
||||||
|
func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
var params ResourceReadParams
|
||||||
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find resource that matches the URI
|
||||||
|
s.resourcesLock.Lock()
|
||||||
|
resource, exists := s.resources[params.URI]
|
||||||
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||||
|
params.URI), errCodeResourceNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no handler is provided, return an empty content array
|
||||||
|
if resource.Handler == nil {
|
||||||
|
result := ResourceReadResult{
|
||||||
|
Contents: []ResourceContent{
|
||||||
|
{
|
||||||
|
URI: params.URI,
|
||||||
|
MimeType: resource.MimeType,
|
||||||
|
Text: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the resource handler
|
||||||
|
content, err := resource.Handler(ctx)
|
||||||
|
if err != nil {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||||
|
errCodeInternalError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the URI is set if not already provided by the handler
|
||||||
|
if len(content.URI) == 0 {
|
||||||
|
content.URI = params.URI
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure MimeType is set if available from the resource definition
|
||||||
|
if len(content.MimeType) == 0 && len(resource.MimeType) > 0 {
|
||||||
|
content.MimeType = resource.MimeType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create response with contents from the handler
|
||||||
|
// The MCP specification requires a contents array
|
||||||
|
result := ResourceReadResult{
|
||||||
|
Contents: []ResourceContent{content},
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendResponse(ctx, client, req.ID, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processResourceSubscribe processes the resources/subscribe request
|
||||||
|
func (s *sseMcpServer) processResourceSubscribe(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
var params ResourceSubscribeParams
|
||||||
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the resource exists
|
||||||
|
s.resourcesLock.Lock()
|
||||||
|
_, exists := s.resources[params.URI]
|
||||||
|
s.resourcesLock.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||||
|
params.URI), errCodeResourceNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send success response for the subscription
|
||||||
|
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// processNotificationCancelled processes the notifications/cancelled notification
|
||||||
|
func (s *sseMcpServer) processNotificationCancelled(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
// Extract the requestId that was canceled
|
||||||
|
type CancelParams struct {
|
||||||
|
RequestId int64 `json:"requestId"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var params CancelParams
|
||||||
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||||
|
logx.Errorf("Failed to parse cancellation params: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.Infof("Request %d was cancelled by client. Reason: %s", params.RequestId, params.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processNotificationInitialized processes the notifications/initialized notification
|
||||||
|
func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
|
||||||
|
// Mark the client as properly initialized
|
||||||
|
client.initialized = true
|
||||||
|
logx.Infof("Client %s is now fully initialized and ready for normal operations", client.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processPing processes the ping request and responds immediately
|
||||||
|
func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req Request) {
|
||||||
|
// A ping request should simply respond with an empty result to confirm the server is alive
|
||||||
|
logx.Infof("Received ping request with ID: %d", req.ID)
|
||||||
|
|
||||||
|
// Send an empty response with client's original request ID
|
||||||
|
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendErrorResponse sends an error response via the SSE channel
|
||||||
|
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||||
|
id any, message string, code int) {
|
||||||
|
errorResponse := struct {
|
||||||
|
JsonRpc string `json:"jsonrpc"`
|
||||||
|
ID any `json:"id"`
|
||||||
|
Error errorMessage `json:"error"`
|
||||||
|
}{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: id,
|
||||||
|
Error: errorMessage{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// all fields are primitive types, impossible to fail
|
||||||
|
jsonData, _ := json.Marshal(errorResponse)
|
||||||
|
// Use CRLF line endings as requested
|
||||||
|
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||||
|
logx.Infof("Sending error for ID %v: %s", id, sseMessage)
|
||||||
|
|
||||||
|
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||||
|
select {
|
||||||
|
case client.channel <- sseMessage:
|
||||||
|
default:
|
||||||
|
// Channel buffer is full, log warning and continue
|
||||||
|
logx.Infof("Client %s channel is full while sending error response with ID %d", client.id, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendResponse sends a success response via the SSE channel
|
||||||
|
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) {
|
||||||
|
response := Response{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: id,
|
||||||
|
Result: result,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
s.sendErrorResponse(ctx, client, id, "Failed to marshal response", errCodeInternalError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use CRLF line endings as requested
|
||||||
|
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||||
|
logx.Infof("Sending response for ID %v: %s", id, sseMessage)
|
||||||
|
|
||||||
|
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||||
|
select {
|
||||||
|
case client.channel <- sseMessage:
|
||||||
|
default:
|
||||||
|
// Channel buffer is full, log warning and continue
|
||||||
|
logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
3451
mcp/server_test.go
Normal file
3451
mcp/server_test.go
Normal file
File diff suppressed because it is too large
Load Diff
317
mcp/types.go
Normal file
317
mcp/types.go
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/rest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cursor is an opaque token used for pagination
|
||||||
|
type Cursor string
|
||||||
|
|
||||||
|
// Request represents a generic MCP request following JSON-RPC 2.0 specification
|
||||||
|
type Request struct {
|
||||||
|
SessionId string `form:"session_id"` // Session identifier for client tracking
|
||||||
|
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
|
||||||
|
ID any `json:"id"` // Request identifier for matching responses
|
||||||
|
Method string `json:"method"` // Method name to invoke
|
||||||
|
Params json.RawMessage `json:"params"` // Parameters for the method
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Request) isNotification() (bool, error) {
|
||||||
|
switch val := r.ID.(type) {
|
||||||
|
case int:
|
||||||
|
return val == 0, nil
|
||||||
|
case int64:
|
||||||
|
return val == 0, nil
|
||||||
|
case float64:
|
||||||
|
return val == 0.0, nil
|
||||||
|
case string:
|
||||||
|
return len(val) == 0, nil
|
||||||
|
case nil:
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("invalid type %T", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type PaginatedParams struct {
|
||||||
|
Cursor string `json:"cursor"`
|
||||||
|
Meta struct {
|
||||||
|
ProgressToken any `json:"progressToken"`
|
||||||
|
} `json:"_meta"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Result is the base interface for all results
|
||||||
|
type Result struct {
|
||||||
|
Meta map[string]any `json:"_meta,omitempty"` // Optional metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// PaginatedResult is a base for results that support pagination
|
||||||
|
type PaginatedResult struct {
|
||||||
|
Result
|
||||||
|
NextCursor Cursor `json:"nextCursor,omitempty"` // Opaque token for fetching next page
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListToolsResult represents the response to a tools/list request
|
||||||
|
type ListToolsResult struct {
|
||||||
|
PaginatedResult
|
||||||
|
Tools []Tool `json:"tools"` // List of available tools
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message Content Types
|
||||||
|
|
||||||
|
// RoleType represents the sender or recipient of messages in a conversation
|
||||||
|
type RoleType string
|
||||||
|
|
||||||
|
// PromptArgument defines a single argument that can be passed to a prompt
|
||||||
|
type PromptArgument struct {
|
||||||
|
Name string `json:"name"` // Argument name
|
||||||
|
Description string `json:"description,omitempty"` // Human-readable description
|
||||||
|
Required bool `json:"required,omitempty"` // Whether this argument is required
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromptHandler is a function that dynamically generates prompt content
|
||||||
|
type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error)
|
||||||
|
|
||||||
|
// Prompt represents an MCP Prompt definition
|
||||||
|
type Prompt struct {
|
||||||
|
Name string `json:"name"` // Unique identifier for the prompt
|
||||||
|
Description string `json:"description,omitempty"` // Human-readable description
|
||||||
|
Arguments []PromptArgument `json:"arguments,omitempty"` // Arguments for customization
|
||||||
|
Content string `json:"-"` // Static content (internal use only)
|
||||||
|
Handler PromptHandler `json:"-"` // Handler for dynamic content generation
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromptMessage represents a message in a conversation
|
||||||
|
type PromptMessage struct {
|
||||||
|
Role RoleType `json:"role"` // Message sender role
|
||||||
|
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextContent represents text content in a message
|
||||||
|
type TextContent struct {
|
||||||
|
Text string `json:"text"` // The text content
|
||||||
|
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
||||||
|
}
|
||||||
|
|
||||||
|
type typedTextContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
TextContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageContent represents image data in a message
|
||||||
|
type ImageContent struct {
|
||||||
|
Data string `json:"data"` // Base64-encoded image data
|
||||||
|
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
||||||
|
}
|
||||||
|
|
||||||
|
type typedImageContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
ImageContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// AudioContent represents audio data in a message
|
||||||
|
type AudioContent struct {
|
||||||
|
Data string `json:"data"` // Base64-encoded audio data
|
||||||
|
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
||||||
|
}
|
||||||
|
|
||||||
|
type typedAudioContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
AudioContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileContent represents file content
|
||||||
|
type FileContent struct {
|
||||||
|
URI string `json:"uri"` // URI identifying the file
|
||||||
|
MimeType string `json:"mimeType"` // MIME type of the file
|
||||||
|
Text string `json:"text"` // File content as text
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddedResource represents a resource embedded in a message
|
||||||
|
type EmbeddedResource struct {
|
||||||
|
Type string `json:"type"` // Always "resource"
|
||||||
|
Resource ResourceContent `json:"resource"` // The resource data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Annotations provides additional metadata for content
|
||||||
|
type Annotations struct {
|
||||||
|
Audience []RoleType `json:"audience,omitempty"` // Who should see this content
|
||||||
|
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool-related Types
|
||||||
|
|
||||||
|
// ToolHandler is a function that handles tool calls
|
||||||
|
type ToolHandler func(ctx context.Context, params map[string]any) (any, error)
|
||||||
|
|
||||||
|
// Tool represents a Model Context Protocol Tool definition
|
||||||
|
type Tool struct {
|
||||||
|
Name string `json:"name"` // Unique identifier for the tool
|
||||||
|
Description string `json:"description"` // Human-readable description
|
||||||
|
InputSchema InputSchema `json:"inputSchema"` // JSON Schema for parameters
|
||||||
|
Handler ToolHandler `json:"-"` // Not sent to clients
|
||||||
|
}
|
||||||
|
|
||||||
|
// InputSchema represents tool's input schema in JSON Schema format
|
||||||
|
type InputSchema struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties map[string]any `json:"properties"` // Property definitions
|
||||||
|
Required []string `json:"required,omitempty"` // List of required properties
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallToolResult represents a tool call result that conforms to the MCP schema
|
||||||
|
type CallToolResult struct {
|
||||||
|
Result
|
||||||
|
Content []any `json:"content"` // Content items (text, images, etc.)
|
||||||
|
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resource represents a Model Context Protocol Resource definition
|
||||||
|
type Resource struct {
|
||||||
|
URI string `json:"uri"` // Unique resource identifier (RFC3986)
|
||||||
|
Name string `json:"name"` // Human-readable name
|
||||||
|
Description string `json:"description,omitempty"` // Optional description
|
||||||
|
MimeType string `json:"mimeType,omitempty"` // Optional MIME type
|
||||||
|
Handler ResourceHandler `json:"-"` // Internal handler not sent to clients
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceHandler is a function that handles resource read requests
|
||||||
|
type ResourceHandler func(ctx context.Context) (ResourceContent, error)
|
||||||
|
|
||||||
|
// ResourceContent represents the content of a resource
|
||||||
|
type ResourceContent struct {
|
||||||
|
URI string `json:"uri"` // Resource URI (required)
|
||||||
|
MimeType string `json:"mimeType,omitempty"` // MIME type of the resource
|
||||||
|
Text string `json:"text,omitempty"` // Text content (if available)
|
||||||
|
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourcesListResult represents the response to a resources/list request
|
||||||
|
type ResourcesListResult struct {
|
||||||
|
PaginatedResult
|
||||||
|
Resources []Resource `json:"resources"` // List of available resources
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceReadParams contains parameters for a resources/read request
|
||||||
|
type ResourceReadParams struct {
|
||||||
|
URI string `json:"uri"` // URI of the resource to read
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceReadResult contains the result of a resources/read request
|
||||||
|
type ResourceReadResult struct {
|
||||||
|
Result
|
||||||
|
Contents []ResourceContent `json:"contents"` // Array of resource content
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceSubscribeParams contains parameters for a resources/subscribe request
|
||||||
|
type ResourceSubscribeParams struct {
|
||||||
|
URI string `json:"uri"` // URI of the resource to subscribe to
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResourceUpdateNotification represents a notification about a resource update
|
||||||
|
type ResourceUpdateNotification struct {
|
||||||
|
URI string `json:"uri"` // URI of the updated resource
|
||||||
|
Content ResourceContent `json:"content"` // New resource content
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client and Server Types
|
||||||
|
|
||||||
|
// mcpClient represents an SSE client connection
|
||||||
|
type mcpClient struct {
|
||||||
|
id string // Unique client identifier
|
||||||
|
channel chan string // Channel for sending SSE messages
|
||||||
|
initialized bool // Tracks if client has sent notifications/initialized
|
||||||
|
}
|
||||||
|
|
||||||
|
// McpServer defines the interface for Model Context Protocol servers
|
||||||
|
type McpServer interface {
|
||||||
|
Start()
|
||||||
|
Stop()
|
||||||
|
RegisterTool(tool Tool) error
|
||||||
|
RegisterPrompt(prompt Prompt)
|
||||||
|
RegisterResource(resource Resource)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sseMcpServer implements the McpServer interface using SSE
|
||||||
|
type sseMcpServer struct {
|
||||||
|
conf McpConf
|
||||||
|
server *rest.Server
|
||||||
|
clients map[string]*mcpClient
|
||||||
|
clientsLock sync.Mutex
|
||||||
|
tools map[string]Tool
|
||||||
|
toolsLock sync.Mutex
|
||||||
|
prompts map[string]Prompt
|
||||||
|
promptsLock sync.Mutex
|
||||||
|
resources map[string]Resource
|
||||||
|
resourcesLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response Types
|
||||||
|
|
||||||
|
// errorObj represents a JSON-RPC error object
|
||||||
|
type errorObj struct {
|
||||||
|
Code int `json:"code"` // Error code
|
||||||
|
Message string `json:"message"` // Error message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response represents a JSON-RPC response
|
||||||
|
type Response struct {
|
||||||
|
JsonRpc string `json:"jsonrpc"` // Always "2.0"
|
||||||
|
ID any `json:"id"` // Same as request ID
|
||||||
|
Result any `json:"result"` // Result object (null if error)
|
||||||
|
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server Information Types
|
||||||
|
|
||||||
|
// serverInfo provides information about the server
|
||||||
|
type serverInfo struct {
|
||||||
|
Name string `json:"name"` // Server name
|
||||||
|
Version string `json:"version"` // Server version
|
||||||
|
}
|
||||||
|
|
||||||
|
// capabilities describes the server's capabilities
|
||||||
|
type capabilities struct {
|
||||||
|
Logging struct{} `json:"logging"`
|
||||||
|
Prompts struct {
|
||||||
|
ListChanged bool `json:"listChanged"` // Server will notify on prompt changes
|
||||||
|
} `json:"prompts"`
|
||||||
|
Resources struct {
|
||||||
|
Subscribe bool `json:"subscribe"` // Server supports resource subscriptions
|
||||||
|
ListChanged bool `json:"listChanged"` // Server will notify on resource changes
|
||||||
|
} `json:"resources"`
|
||||||
|
Tools struct {
|
||||||
|
ListChanged bool `json:"listChanged"` // Server will notify on tool changes
|
||||||
|
} `json:"tools"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// initializationResponse is sent in response to an initialize request
|
||||||
|
type initializationResponse struct {
|
||||||
|
ProtocolVersion string `json:"protocolVersion"` // Protocol version
|
||||||
|
Capabilities capabilities `json:"capabilities"` // Server capabilities
|
||||||
|
ServerInfo serverInfo `json:"serverInfo"` // Server information
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolCallParams contains the parameters for a tool call
|
||||||
|
type ToolCallParams struct {
|
||||||
|
Name string `json:"name"` // Tool name
|
||||||
|
Parameters map[string]any `json:"parameters"` // Tool parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolResult contains the result of a tool execution
|
||||||
|
type ToolResult struct {
|
||||||
|
Type string `json:"type"` // Content type (text, image, etc.)
|
||||||
|
Content any `json:"content"` // Result content
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorMessage represents a detailed error message
|
||||||
|
type errorMessage struct {
|
||||||
|
Code int `json:"code"` // Error code
|
||||||
|
Message string `json:"message"` // Error message
|
||||||
|
Data any `json:",omitempty"` // Additional error data
|
||||||
|
}
|
||||||
271
mcp/types_test.go
Normal file
271
mcp/types_test.go
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResponseMarshaling(t *testing.T) {
|
||||||
|
// Test that the Response struct marshals correctly
|
||||||
|
resp := Response{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 123,
|
||||||
|
Result: map[string]string{
|
||||||
|
"key": "value",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
||||||
|
assert.Contains(t, string(data), `"id":123`)
|
||||||
|
assert.Contains(t, string(data), `"result":{"key":"value"}`)
|
||||||
|
|
||||||
|
// Test response with error
|
||||||
|
respWithError := Response{
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: 456,
|
||||||
|
Error: &errorObj{
|
||||||
|
Code: errCodeInvalidRequest,
|
||||||
|
Message: "Invalid Request",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = json.Marshal(respWithError)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
||||||
|
assert.Contains(t, string(data), `"id":456`)
|
||||||
|
assert.Contains(t, string(data), `"error":{"code":-32600,"message":"Invalid Request"}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestUnmarshaling(t *testing.T) {
|
||||||
|
// Test that the Request struct unmarshals correctly
|
||||||
|
jsonStr := `{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 789,
|
||||||
|
"method": "test_method",
|
||||||
|
"params": {"key": "value"}
|
||||||
|
}`
|
||||||
|
|
||||||
|
var req Request
|
||||||
|
err := json.Unmarshal([]byte(jsonStr), &req)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "2.0", req.JsonRpc)
|
||||||
|
assert.Equal(t, float64(789), req.ID)
|
||||||
|
assert.Equal(t, "test_method", req.Method)
|
||||||
|
|
||||||
|
// Check params unmarshaled correctly
|
||||||
|
var params map[string]string
|
||||||
|
err = json.Unmarshal(req.Params, ¶ms)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "value", params["key"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolStructs(t *testing.T) {
|
||||||
|
// Test Tool struct
|
||||||
|
tool := Tool{
|
||||||
|
Name: "test.tool",
|
||||||
|
Description: "A test tool",
|
||||||
|
InputSchema: InputSchema{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]any{
|
||||||
|
"input": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Input parameter",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"input"},
|
||||||
|
},
|
||||||
|
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||||
|
return "result", nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields are correct
|
||||||
|
assert.Equal(t, "test.tool", tool.Name)
|
||||||
|
assert.Equal(t, "A test tool", tool.Description)
|
||||||
|
assert.Equal(t, "object", tool.InputSchema.Type)
|
||||||
|
assert.Contains(t, tool.InputSchema.Properties, "input")
|
||||||
|
propMap, ok := tool.InputSchema.Properties["input"].(map[string]any)
|
||||||
|
assert.True(t, ok, "Property should be a map")
|
||||||
|
assert.Equal(t, "string", propMap["type"])
|
||||||
|
assert.NotNil(t, tool.Handler)
|
||||||
|
|
||||||
|
// Verify JSON marshalling (which should exclude Handler function)
|
||||||
|
data, err := json.Marshal(tool)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"name":"test.tool"`)
|
||||||
|
assert.Contains(t, string(data), `"description":"A test tool"`)
|
||||||
|
assert.Contains(t, string(data), `"inputSchema":`)
|
||||||
|
assert.NotContains(t, string(data), `"Handler":`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPromptStructs(t *testing.T) {
|
||||||
|
// Test Prompt struct
|
||||||
|
prompt := Prompt{
|
||||||
|
Name: "test.prompt",
|
||||||
|
Description: "A test prompt description",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields are correct
|
||||||
|
assert.Equal(t, "test.prompt", prompt.Name)
|
||||||
|
assert.Equal(t, "A test prompt description", prompt.Description)
|
||||||
|
|
||||||
|
// Verify JSON marshalling
|
||||||
|
data, err := json.Marshal(prompt)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"name":"test.prompt"`)
|
||||||
|
assert.Contains(t, string(data), `"description":"A test prompt description"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResourceStructs(t *testing.T) {
|
||||||
|
// Test Resource struct
|
||||||
|
resource := Resource{
|
||||||
|
Name: "test.resource",
|
||||||
|
URI: "http://example.com/resource",
|
||||||
|
Description: "A test resource",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields are correct
|
||||||
|
assert.Equal(t, "test.resource", resource.Name)
|
||||||
|
assert.Equal(t, "http://example.com/resource", resource.URI)
|
||||||
|
assert.Equal(t, "A test resource", resource.Description)
|
||||||
|
|
||||||
|
// Verify JSON marshalling
|
||||||
|
data, err := json.Marshal(resource)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"name":"test.resource"`)
|
||||||
|
assert.Contains(t, string(data), `"uri":"http://example.com/resource"`)
|
||||||
|
assert.Contains(t, string(data), `"description":"A test resource"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContentTypes(t *testing.T) {
|
||||||
|
// Test TextContent
|
||||||
|
textContent := TextContent{
|
||||||
|
Text: "Sample text",
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
Priority: ptr(1.0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(textContent)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"text":"Sample text"`)
|
||||||
|
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
||||||
|
assert.Contains(t, string(data), `"priority":1`)
|
||||||
|
|
||||||
|
// Test ImageContent
|
||||||
|
imageContent := ImageContent{
|
||||||
|
Data: "base64data",
|
||||||
|
MimeType: "image/png",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = json.Marshal(imageContent)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"data":"base64data"`)
|
||||||
|
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
||||||
|
|
||||||
|
// Test AudioContent
|
||||||
|
audioContent := AudioContent{
|
||||||
|
Data: "base64audio",
|
||||||
|
MimeType: "audio/mp3",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = json.Marshal(audioContent)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"data":"base64audio"`)
|
||||||
|
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallToolResult(t *testing.T) {
|
||||||
|
// Test CallToolResult
|
||||||
|
result := CallToolResult{
|
||||||
|
Result: Result{
|
||||||
|
Meta: map[string]any{
|
||||||
|
"progressToken": "token123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Content: []interface{}{
|
||||||
|
TextContent{
|
||||||
|
Text: "Sample result",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IsError: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(result)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
||||||
|
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
||||||
|
assert.NotContains(t, string(data), `"isError":`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequest_isNotification(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
id any
|
||||||
|
want bool
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
// integer test cases
|
||||||
|
{name: "int zero", id: 0, want: true, wantErr: nil},
|
||||||
|
{name: "int non-zero", id: 1, want: false, wantErr: nil},
|
||||||
|
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
|
||||||
|
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
|
||||||
|
|
||||||
|
// floating point number test cases
|
||||||
|
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
|
||||||
|
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
|
||||||
|
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
|
||||||
|
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
|
||||||
|
|
||||||
|
// string test cases
|
||||||
|
{name: "empty string", id: "", want: true, wantErr: nil},
|
||||||
|
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
|
||||||
|
{name: "space string", id: " ", want: false, wantErr: nil},
|
||||||
|
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
|
||||||
|
|
||||||
|
// special cases
|
||||||
|
{name: "nil", id: nil, want: true, wantErr: nil},
|
||||||
|
|
||||||
|
// logical type test cases
|
||||||
|
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
|
||||||
|
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
|
||||||
|
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
|
||||||
|
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
|
||||||
|
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
|
||||||
|
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
|
||||||
|
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := Request{
|
||||||
|
SessionId: "test-session",
|
||||||
|
JsonRpc: "2.0",
|
||||||
|
ID: tt.id,
|
||||||
|
Method: "testMethod",
|
||||||
|
Params: json.RawMessage(`{}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := req.isNotification()
|
||||||
|
|
||||||
|
if (err != nil) != (tt.wantErr != nil) {
|
||||||
|
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
|
||||||
|
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
107
mcp/util.go
Normal file
107
mcp/util.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
||||||
|
func formatSSEMessage(event string, data []byte) string {
|
||||||
|
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ptr is a helper function to get a pointer to a value
|
||||||
|
func ptr[T any](v T) *T {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|
||||||
|
func toTypedContents(contents []any) []any {
|
||||||
|
typedContents := make([]any, len(contents))
|
||||||
|
|
||||||
|
for i, content := range contents {
|
||||||
|
switch v := content.(type) {
|
||||||
|
case TextContent:
|
||||||
|
typedContents[i] = typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: v,
|
||||||
|
}
|
||||||
|
case ImageContent:
|
||||||
|
typedContents[i] = typedImageContent{
|
||||||
|
Type: ContentTypeImage,
|
||||||
|
ImageContent: v,
|
||||||
|
}
|
||||||
|
case AudioContent:
|
||||||
|
typedContents[i] = typedAudioContent{
|
||||||
|
Type: ContentTypeAudio,
|
||||||
|
AudioContent: v,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
typedContents[i] = typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: TextContent{
|
||||||
|
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return typedContents
|
||||||
|
}
|
||||||
|
|
||||||
|
func toTypedPromptMessages(messages []PromptMessage) []PromptMessage {
|
||||||
|
typedMessages := make([]PromptMessage, len(messages))
|
||||||
|
|
||||||
|
for i, msg := range messages {
|
||||||
|
switch v := msg.Content.(type) {
|
||||||
|
case TextContent:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case ImageContent:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedImageContent{
|
||||||
|
Type: ContentTypeImage,
|
||||||
|
ImageContent: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case AudioContent:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedAudioContent{
|
||||||
|
Type: ContentTypeAudio,
|
||||||
|
AudioContent: v,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
typedMessages[i] = PromptMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: typedTextContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
TextContent: TextContent{
|
||||||
|
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return typedMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePromptArguments checks if all required arguments are provided
|
||||||
|
// Returns a list of missing required arguments
|
||||||
|
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||||
|
var missingArgs []string
|
||||||
|
|
||||||
|
for _, arg := range prompt.Arguments {
|
||||||
|
if arg.Required {
|
||||||
|
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||||
|
missingArgs = append(missingArgs, arg.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return missingArgs
|
||||||
|
}
|
||||||
274
mcp/util_test.go
Normal file
274
mcp/util_test.go
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Event struct {
|
||||||
|
Type string
|
||||||
|
Data map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEvent(input string) (*Event, error) {
|
||||||
|
var evt Event
|
||||||
|
var dataStr string
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(strings.NewReader(input))
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if strings.HasPrefix(line, "event:") {
|
||||||
|
evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||||
|
} else if strings.HasPrefix(line, "data:") {
|
||||||
|
dataStr = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(dataStr) > 0 {
|
||||||
|
if err := json.Unmarshal([]byte(dataStr), &evt.Data); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse data: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &evt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToTypedPromptMessages tests the toTypedPromptMessages function
|
||||||
|
func TestToTypedPromptMessages(t *testing.T) {
|
||||||
|
// Test with multiple message types in one test
|
||||||
|
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||||
|
// Create test data with different content types
|
||||||
|
messages := []PromptMessage{
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: TextContent{
|
||||||
|
Text: "Hello, this is a text message",
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
Priority: ptr(0.8),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: RoleAssistant,
|
||||||
|
Content: ImageContent{
|
||||||
|
Data: "base64ImageData",
|
||||||
|
MimeType: "image/jpeg",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: AudioContent{
|
||||||
|
Data: "base64AudioData",
|
||||||
|
MimeType: "audio/mp3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "system",
|
||||||
|
Content: "This is a simple string that should be handled as unknown type",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the function
|
||||||
|
result := toTypedPromptMessages(messages)
|
||||||
|
|
||||||
|
// Validate results
|
||||||
|
require.Len(t, result, 4, "Should return the same number of messages")
|
||||||
|
|
||||||
|
// Validate first message (TextContent)
|
||||||
|
msg := result[0]
|
||||||
|
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Type assertion using reflection since Content is an interface
|
||||||
|
typed, ok := msg.Content.(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||||
|
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||||
|
assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||||
|
|
||||||
|
// Validate second message (ImageContent)
|
||||||
|
msg = result[1]
|
||||||
|
assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Type assertion for image content
|
||||||
|
typedImg, ok := msg.Content.(typedImageContent)
|
||||||
|
require.True(t, ok, "Should be typedImageContent")
|
||||||
|
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||||
|
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||||
|
assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate third message (AudioContent)
|
||||||
|
msg = result[2]
|
||||||
|
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Type assertion for audio content
|
||||||
|
typedAudio, ok := msg.Content.(typedAudioContent)
|
||||||
|
require.True(t, ok, "Should be typedAudioContent")
|
||||||
|
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||||
|
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||||
|
assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate fourth message (unknown type converted to TextContent)
|
||||||
|
msg = result[3]
|
||||||
|
assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved")
|
||||||
|
|
||||||
|
// Should be converted to a typedTextContent with error message
|
||||||
|
typedUnknown, ok := msg.Content.(typedTextContent)
|
||||||
|
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test empty input
|
||||||
|
t.Run("EmptyInput", func(t *testing.T) {
|
||||||
|
messages := []PromptMessage{}
|
||||||
|
result := toTypedPromptMessages(messages)
|
||||||
|
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with nil annotations
|
||||||
|
t.Run("NilAnnotations", func(t *testing.T) {
|
||||||
|
messages := []PromptMessage{
|
||||||
|
{
|
||||||
|
Role: RoleUser,
|
||||||
|
Content: TextContent{
|
||||||
|
Text: "Text with nil annotations",
|
||||||
|
Annotations: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := toTypedPromptMessages(messages)
|
||||||
|
require.Len(t, result, 1, "Should return one message")
|
||||||
|
|
||||||
|
typed, ok := result[0].Content.(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||||
|
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToTypedContents tests the toTypedContents function
|
||||||
|
func TestToTypedContents(t *testing.T) {
|
||||||
|
// Test with multiple content types in one test
|
||||||
|
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||||
|
// Create test data with different content types
|
||||||
|
contents := []any{
|
||||||
|
TextContent{
|
||||||
|
Text: "Hello, this is a text content",
|
||||||
|
Annotations: &Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
Priority: ptr(0.7),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ImageContent{
|
||||||
|
Data: "base64ImageData",
|
||||||
|
MimeType: "image/png",
|
||||||
|
},
|
||||||
|
AudioContent{
|
||||||
|
Data: "base64AudioData",
|
||||||
|
MimeType: "audio/wav",
|
||||||
|
},
|
||||||
|
"This is a simple string that should be handled as unknown type",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the function
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
|
||||||
|
// Validate results
|
||||||
|
require.Len(t, result, 4, "Should return the same number of contents")
|
||||||
|
|
||||||
|
// Validate first content (TextContent)
|
||||||
|
typed, ok := result[0].(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||||
|
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||||
|
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||||
|
assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||||
|
|
||||||
|
// Validate second content (ImageContent)
|
||||||
|
typedImg, ok := result[1].(typedImageContent)
|
||||||
|
require.True(t, ok, "Should be typedImageContent")
|
||||||
|
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||||
|
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||||
|
assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate third content (AudioContent)
|
||||||
|
typedAudio, ok := result[2].(typedAudioContent)
|
||||||
|
require.True(t, ok, "Should be typedAudioContent")
|
||||||
|
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||||
|
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||||
|
assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved")
|
||||||
|
|
||||||
|
// Validate fourth content (unknown type converted to TextContent)
|
||||||
|
typedUnknown, ok := result[3].(typedTextContent)
|
||||||
|
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||||
|
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test empty input
|
||||||
|
t.Run("EmptyInput", func(t *testing.T) {
|
||||||
|
contents := []any{}
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with nil annotations
|
||||||
|
t.Run("NilAnnotations", func(t *testing.T) {
|
||||||
|
contents := []any{
|
||||||
|
TextContent{
|
||||||
|
Text: "Text with nil annotations",
|
||||||
|
Annotations: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
require.Len(t, result, 1, "Should return one content")
|
||||||
|
|
||||||
|
typed, ok := result[0].(typedTextContent)
|
||||||
|
require.True(t, ok, "Should be typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||||
|
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with custom struct (should be handled as unknown type)
|
||||||
|
t.Run("CustomStruct", func(t *testing.T) {
|
||||||
|
type CustomContent struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := []any{
|
||||||
|
CustomContent{
|
||||||
|
Data: "custom data",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := toTypedContents(contents)
|
||||||
|
require.Len(t, result, 1, "Should return one content")
|
||||||
|
|
||||||
|
typed, ok := result[0].(typedTextContent)
|
||||||
|
require.True(t, ok, "Custom struct should be converted to typedTextContent")
|
||||||
|
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||||
|
assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||||
|
assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type")
|
||||||
|
})
|
||||||
|
}
|
||||||
149
mcp/vars.go
Normal file
149
mcp/vars.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/syncx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Protocol constants
|
||||||
|
const (
|
||||||
|
// JSON-RPC version as defined in the specification
|
||||||
|
jsonRpcVersion = "2.0"
|
||||||
|
|
||||||
|
// Session identifier key used in request URLs
|
||||||
|
sessionIdKey = "session_id"
|
||||||
|
|
||||||
|
// progressTokenKey is used to track progress of long-running tasks
|
||||||
|
progressTokenKey = "progressToken"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server-Sent Events (SSE) event types
|
||||||
|
const (
|
||||||
|
// Standard message event for JSON-RPC responses
|
||||||
|
eventMessage = "message"
|
||||||
|
|
||||||
|
// Endpoint event for sending endpoint URL to clients
|
||||||
|
eventEndpoint = "endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Content type identifiers
|
||||||
|
const (
|
||||||
|
// ContentTypeObject is object content type
|
||||||
|
ContentTypeObject = "object"
|
||||||
|
|
||||||
|
// ContentTypeText is text content type
|
||||||
|
ContentTypeText = "text"
|
||||||
|
|
||||||
|
// ContentTypeImage is image content type
|
||||||
|
ContentTypeImage = "image"
|
||||||
|
|
||||||
|
// ContentTypeAudio is audio content type
|
||||||
|
ContentTypeAudio = "audio"
|
||||||
|
|
||||||
|
// ContentTypeResource is resource content type
|
||||||
|
ContentTypeResource = "resource"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Collection keys for broadcast events
|
||||||
|
const (
|
||||||
|
// Key for prompts collection
|
||||||
|
keyPrompts = "prompts"
|
||||||
|
|
||||||
|
// Key for resources collection
|
||||||
|
keyResources = "resources"
|
||||||
|
|
||||||
|
// Key for tools collection
|
||||||
|
keyTools = "tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JSON-RPC error codes
|
||||||
|
// Standard error codes from JSON-RPC 2.0 spec
|
||||||
|
const (
|
||||||
|
// Invalid JSON was received by the server
|
||||||
|
errCodeInvalidRequest = -32600
|
||||||
|
|
||||||
|
// The method does not exist / is not available
|
||||||
|
errCodeMethodNotFound = -32601
|
||||||
|
|
||||||
|
// Invalid method parameter(s)
|
||||||
|
errCodeInvalidParams = -32602
|
||||||
|
|
||||||
|
// Internal JSON-RPC error
|
||||||
|
errCodeInternalError = -32603
|
||||||
|
|
||||||
|
// Tool execution timed out
|
||||||
|
errCodeTimeout = -32001
|
||||||
|
|
||||||
|
// Resource not found error
|
||||||
|
errCodeResourceNotFound = -32002
|
||||||
|
|
||||||
|
// Client hasn't completed initialization
|
||||||
|
errCodeClientNotInitialized = -32800
|
||||||
|
)
|
||||||
|
|
||||||
|
// User and assistant role definitions
|
||||||
|
const (
|
||||||
|
// RoleUser is the "user" role - the entity asking questions
|
||||||
|
RoleUser RoleType = "user"
|
||||||
|
|
||||||
|
// RoleAssistant is the "assistant" role - the entity providing responses
|
||||||
|
RoleAssistant RoleType = "assistant"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Method names as defined in the MCP specification
|
||||||
|
const (
|
||||||
|
// Initialize the connection between client and server
|
||||||
|
methodInitialize = "initialize"
|
||||||
|
|
||||||
|
// List available tools
|
||||||
|
methodToolsList = "tools/list"
|
||||||
|
|
||||||
|
// Call a specific tool
|
||||||
|
methodToolsCall = "tools/call"
|
||||||
|
|
||||||
|
// List available prompts
|
||||||
|
methodPromptsList = "prompts/list"
|
||||||
|
|
||||||
|
// Get a specific prompt
|
||||||
|
methodPromptsGet = "prompts/get"
|
||||||
|
|
||||||
|
// List available resources
|
||||||
|
methodResourcesList = "resources/list"
|
||||||
|
|
||||||
|
// Read a specific resource
|
||||||
|
methodResourcesRead = "resources/read"
|
||||||
|
|
||||||
|
// Subscribe to resource updates
|
||||||
|
methodResourcesSubscribe = "resources/subscribe"
|
||||||
|
|
||||||
|
// Simple ping to check server availability
|
||||||
|
methodPing = "ping"
|
||||||
|
|
||||||
|
// Notification that client is fully initialized
|
||||||
|
methodNotificationsInitialized = "notifications/initialized"
|
||||||
|
|
||||||
|
// Notification that a request was canceled
|
||||||
|
methodNotificationsCancelled = "notifications/cancelled"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Event names for Server-Sent Events (SSE)
|
||||||
|
const (
|
||||||
|
// Notification of tool list changes
|
||||||
|
eventToolsListChanged = "tools/list_changed"
|
||||||
|
|
||||||
|
// Notification of prompt list changes
|
||||||
|
eventPromptsListChanged = "prompts/list_changed"
|
||||||
|
|
||||||
|
// Notification of resource list changes
|
||||||
|
eventResourcesListChanged = "resources/list_changed"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Default channel size for events
|
||||||
|
eventChanSize = 10
|
||||||
|
|
||||||
|
// Default ping interval for checking connection availability
|
||||||
|
// use syncx.ForAtomicDuration to ensure atomicity in test race
|
||||||
|
pingInterval = syncx.ForAtomicDuration(30 * time.Second)
|
||||||
|
)
|
||||||
210
mcp/vars_test.go
Normal file
210
mcp/vars_test.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
// filepath: /Users/kevin/Develop/go/opensource/go-zero/mcp/vars_test.go
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestErrorCodes ensures error codes are applied correctly in error responses
|
||||||
|
func TestErrorCodes(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
code int
|
||||||
|
message string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "invalid request error",
|
||||||
|
code: errCodeInvalidRequest,
|
||||||
|
message: "Invalid request",
|
||||||
|
expected: `"code":-32600`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "method not found error",
|
||||||
|
code: errCodeMethodNotFound,
|
||||||
|
message: "Method not found",
|
||||||
|
expected: `"code":-32601`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid params error",
|
||||||
|
code: errCodeInvalidParams,
|
||||||
|
message: "Invalid parameters",
|
||||||
|
expected: `"code":-32602`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "internal error",
|
||||||
|
code: errCodeInternalError,
|
||||||
|
message: "Internal server error",
|
||||||
|
expected: `"code":-32603`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "timeout error",
|
||||||
|
code: errCodeTimeout,
|
||||||
|
message: "Operation timed out",
|
||||||
|
expected: `"code":-32001`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "resource not found error",
|
||||||
|
code: errCodeResourceNotFound,
|
||||||
|
message: "Resource not found",
|
||||||
|
expected: `"code":-32002`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "client not initialized error",
|
||||||
|
code: errCodeClientNotInitialized,
|
||||||
|
message: "Client not initialized",
|
||||||
|
expected: `"code":-32800`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
resp := Response{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: int64(1),
|
||||||
|
Error: &errorObj{
|
||||||
|
Code: tc.code,
|
||||||
|
Message: tc.message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), tc.expected, "Error code should match expected value")
|
||||||
|
assert.Contains(t, string(data), tc.message, "Error message should be included")
|
||||||
|
assert.Contains(t, string(data), jsonRpcVersion, "JSON-RPC version should be included")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestJsonRpcVersion ensures the correct JSON-RPC version is used
|
||||||
|
func TestJsonRpcVersion(t *testing.T) {
|
||||||
|
assert.Equal(t, "2.0", jsonRpcVersion, "JSON-RPC version should be 2.0")
|
||||||
|
|
||||||
|
// Test that it's used in responses
|
||||||
|
resp := Response{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: int64(1),
|
||||||
|
Result: "test",
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(resp)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"jsonrpc":"2.0"`, "Response should use correct JSON-RPC version")
|
||||||
|
|
||||||
|
// Test that it's expected in requests
|
||||||
|
reqStr := `{"jsonrpc":"2.0","id":1,"method":"test"}`
|
||||||
|
var req Request
|
||||||
|
err = json.Unmarshal([]byte(reqStr), &req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, jsonRpcVersion, req.JsonRpc, "Request should parse correct JSON-RPC version")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSessionIdKey ensures session ID extraction works correctly
|
||||||
|
func TestSessionIdKey(t *testing.T) {
|
||||||
|
// Create a mock server implementation
|
||||||
|
mock := newMockMcpServer(t)
|
||||||
|
defer mock.shutdown()
|
||||||
|
|
||||||
|
// Verify the key constant
|
||||||
|
assert.Equal(t, "session_id", sessionIdKey, "Session ID key should be 'session_id'")
|
||||||
|
|
||||||
|
// Test that session ID is extracted correctly
|
||||||
|
mockR := httptest.NewRequest("GET", "/?"+sessionIdKey+"=test-session", nil)
|
||||||
|
|
||||||
|
// Since the mock server is using the same session key logic,
|
||||||
|
// we can test this by accessing the request query parameters directly
|
||||||
|
sessionID := mockR.URL.Query().Get(sessionIdKey)
|
||||||
|
assert.Equal(t, "test-session", sessionID, "Session ID should be extracted correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEventTypes ensures event types are set correctly in SSE responses
|
||||||
|
func TestEventTypes(t *testing.T) {
|
||||||
|
// Test message event
|
||||||
|
assert.Equal(t, "message", eventMessage, "Message event should be 'message'")
|
||||||
|
|
||||||
|
// Test endpoint event
|
||||||
|
assert.Equal(t, "endpoint", eventEndpoint, "Endpoint event should be 'endpoint'")
|
||||||
|
|
||||||
|
// Verify them in an actual SSE format string
|
||||||
|
messageEvent := "event: " + eventMessage + "\ndata: test\n\n"
|
||||||
|
assert.Contains(t, messageEvent, "event: message", "Message event should format correctly")
|
||||||
|
|
||||||
|
endpointEvent := "event: " + eventEndpoint + "\ndata: test\n\n"
|
||||||
|
assert.Contains(t, endpointEvent, "event: endpoint", "Endpoint event should format correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCollectionKeys checks that collection keys are used correctly
|
||||||
|
func TestCollectionKeys(t *testing.T) {
|
||||||
|
// Verify collection key constants
|
||||||
|
assert.Equal(t, "prompts", keyPrompts, "Prompts key should be 'prompts'")
|
||||||
|
assert.Equal(t, "resources", keyResources, "Resources key should be 'resources'")
|
||||||
|
assert.Equal(t, "tools", keyTools, "Tools key should be 'tools'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRoleTypes checks that role types are used correctly
|
||||||
|
func TestRoleTypes(t *testing.T) {
|
||||||
|
// Test in annotations
|
||||||
|
annotations := Annotations{
|
||||||
|
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(annotations)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"audience":["user","assistant"]`, "Role types should marshal correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMethodNames checks that method names are used correctly
|
||||||
|
func TestMethodNames(t *testing.T) {
|
||||||
|
// Verify method name constants
|
||||||
|
methods := map[string]string{
|
||||||
|
"initialize": methodInitialize,
|
||||||
|
"tools/list": methodToolsList,
|
||||||
|
"tools/call": methodToolsCall,
|
||||||
|
"prompts/list": methodPromptsList,
|
||||||
|
"prompts/get": methodPromptsGet,
|
||||||
|
"resources/list": methodResourcesList,
|
||||||
|
"resources/read": methodResourcesRead,
|
||||||
|
"resources/subscribe": methodResourcesSubscribe,
|
||||||
|
"ping": methodPing,
|
||||||
|
"notifications/initialized": methodNotificationsInitialized,
|
||||||
|
"notifications/cancelled": methodNotificationsCancelled,
|
||||||
|
}
|
||||||
|
|
||||||
|
for expected, actual := range methods {
|
||||||
|
assert.Equal(t, expected, actual, "Method name should be "+expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test in a request
|
||||||
|
for methodName := range methods {
|
||||||
|
req := Request{
|
||||||
|
JsonRpc: jsonRpcVersion,
|
||||||
|
ID: int64(1),
|
||||||
|
Method: methodName,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(data), `"method":"`+methodName+`"`, "Method name should be used in requests")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEventNames checks that event names are used correctly
|
||||||
|
func TestEventNames(t *testing.T) {
|
||||||
|
// Verify event name constants
|
||||||
|
events := map[string]string{
|
||||||
|
"tools/list_changed": eventToolsListChanged,
|
||||||
|
"prompts/list_changed": eventPromptsListChanged,
|
||||||
|
"resources/list_changed": eventResourcesListChanged,
|
||||||
|
}
|
||||||
|
|
||||||
|
for expected, actual := range events {
|
||||||
|
assert.Equal(t, expected, actual, "Event name should be "+expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test event names in SSE format
|
||||||
|
for _, eventName := range events {
|
||||||
|
sseEvent := "event: " + eventName + "\ndata: test\n\n"
|
||||||
|
assert.Contains(t, sseEvent, "event: "+eventName, "Event name should format correctly in SSE")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
[English](readme.md) | 简体中文
|
[English](readme.md) | 简体中文
|
||||||
|
|
||||||
[](https://github.com/zeromicro/go-zero/actions)
|
|
||||||
[](https://goreportcard.com/report/github.com/zeromicro/go-zero)
|
[](https://goreportcard.com/report/github.com/zeromicro/go-zero)
|
||||||
[](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg)
|
[](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg)
|
||||||
[](https://codecov.io/gh/zeromicro/go-zero)
|
[](https://codecov.io/gh/zeromicro/go-zero)
|
||||||
@@ -301,6 +300,10 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
|
|||||||
>102. 深圳市兴海物联科技有限公司
|
>102. 深圳市兴海物联科技有限公司
|
||||||
>103. 爱芯元智半导体股份有限公司
|
>103. 爱芯元智半导体股份有限公司
|
||||||
>104. 杭州升恒科技有限公司
|
>104. 杭州升恒科技有限公司
|
||||||
|
>105. 昆仑万维科技股份有限公司
|
||||||
|
>106. 无锡盛算信息技术有限公司
|
||||||
|
>107. 深圳市聚货通信息科技有限公司
|
||||||
|
>108. 浙江银盾云科技有限公司
|
||||||
|
|
||||||
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ go-zero is a web and rpc framework with lots of builtin engineering practices. I
|
|||||||
|
|
||||||
<div align=center>
|
<div align=center>
|
||||||
|
|
||||||
[](https://github.com/zeromicro/go-zero/actions)
|
|
||||||
[](https://codecov.io/gh/zeromicro/go-zero)
|
[](https://codecov.io/gh/zeromicro/go-zero)
|
||||||
[](https://goreportcard.com/report/github.com/zeromicro/go-zero)
|
[](https://goreportcard.com/report/github.com/zeromicro/go-zero)
|
||||||
[](https://github.com/zeromicro/go-zero)
|
[](https://github.com/zeromicro/go-zero)
|
||||||
@@ -251,7 +250,3 @@ go-zero enlisted in the [CNCF Cloud Native Landscape](https://landscape.cncf.io/
|
|||||||
## Give a Star! ⭐
|
## Give a Star! ⭐
|
||||||
|
|
||||||
If you like this project or are using it to learn or start your own solution, give it a star to get updates on new releases. Your support matters!
|
If you like this project or are using it to learn or start your own solution, give it a star to get updates on new releases. Your support matters!
|
||||||
|
|
||||||
## Buy me a coffee
|
|
||||||
|
|
||||||
<a href="https://www.buymeacoffee.com/kevwan" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/rest/handler"
|
"github.com/zeromicro/go-zero/rest/handler"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"github.com/zeromicro/go-zero/rest/internal"
|
"github.com/zeromicro/go-zero/rest/internal"
|
||||||
|
"github.com/zeromicro/go-zero/rest/internal/header"
|
||||||
"github.com/zeromicro/go-zero/rest/internal/response"
|
"github.com/zeromicro/go-zero/rest/internal/response"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,7 +28,10 @@ var ErrSignatureConfig = errors.New("bad config for Signature")
|
|||||||
type engine struct {
|
type engine struct {
|
||||||
conf RestConf
|
conf RestConf
|
||||||
routes []featuredRoutes
|
routes []featuredRoutes
|
||||||
// timeout is the max timeout of all routes
|
// timeout is the max timeout of all routes,
|
||||||
|
// and is used to set http.Server.ReadTimeout and http.Server.WriteTimeout.
|
||||||
|
// this network timeout is used to avoid DoS attacks by sending data slowly
|
||||||
|
// or receiving data slowly with many connections to exhaust server resources.
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
unauthorizedCallback handler.UnauthorizedCallback
|
unauthorizedCallback handler.UnauthorizedCallback
|
||||||
unsignedCallback handler.UnsignedCallback
|
unsignedCallback handler.UnsignedCallback
|
||||||
@@ -54,13 +58,26 @@ func newEngine(c RestConf) *engine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) addRoutes(r featuredRoutes) {
|
func (ng *engine) addRoutes(r featuredRoutes) {
|
||||||
|
if r.sse {
|
||||||
|
r.routes = buildSSERoutes(r.routes)
|
||||||
|
}
|
||||||
ng.routes = append(ng.routes, r)
|
ng.routes = append(ng.routes, r)
|
||||||
|
|
||||||
// need to guarantee the timeout is the max of all routes
|
ng.mightUpdateTimeout(r)
|
||||||
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
|
|
||||||
if r.timeout > ng.timeout {
|
|
||||||
ng.timeout = r.timeout
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildSSERoutes(routes []Route) []Route {
|
||||||
|
for i, route := range routes {
|
||||||
|
h := route.Handler
|
||||||
|
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
|
||||||
|
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
|
||||||
|
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
|
||||||
|
h(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
||||||
@@ -174,11 +191,12 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 {
|
|||||||
return ng.conf.MaxBytes
|
return ng.conf.MaxBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
|
func (ng *engine) checkedTimeout(timeout *time.Duration) time.Duration {
|
||||||
if timeout > 0 {
|
if timeout != nil {
|
||||||
return timeout
|
return *timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if timeout not set in featured routes, use global timeout
|
||||||
return time.Duration(ng.conf.Timeout) * time.Millisecond
|
return time.Duration(ng.conf.Timeout) * time.Millisecond
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,6 +228,32 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
|
|||||||
return ng.shedder
|
return ng.shedder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ng *engine) hasTimeout() bool {
|
||||||
|
return ng.conf.Middlewares.Timeout && ng.timeout > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// mightUpdateTimeout checks if the route timeout is greater than the current,
|
||||||
|
// and updates the engine's timeout accordingly.
|
||||||
|
func (ng *engine) mightUpdateTimeout(r featuredRoutes) {
|
||||||
|
// if global timeout is set to 0, it means no need to set read/write timeout
|
||||||
|
// if route timeout is nil, no need to update ng.timeout
|
||||||
|
if ng.timeout == 0 || r.timeout == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if route timeout is 0 (means no timeout), cannot set read/write timeout
|
||||||
|
if *r.timeout == 0 {
|
||||||
|
ng.timeout = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// need to guarantee the timeout is the max of all routes
|
||||||
|
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
|
||||||
|
if *r.timeout > ng.timeout {
|
||||||
|
ng.timeout = *r.timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// notFoundHandler returns a middleware that handles 404 not found requests.
|
// notFoundHandler returns a middleware that handles 404 not found requests.
|
||||||
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -311,7 +355,7 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// make sure user defined options overwrite default options
|
// make sure user defined options overwrite default options
|
||||||
opts = append([]StartOption{ng.withTimeout()}, opts...)
|
opts = append([]StartOption{ng.withNetworkTimeout()}, opts...)
|
||||||
|
|
||||||
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
|
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
|
||||||
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
|
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
|
||||||
@@ -334,18 +378,19 @@ func (ng *engine) use(middleware Middleware) {
|
|||||||
ng.middlewares = append(ng.middlewares, middleware)
|
ng.middlewares = append(ng.middlewares, middleware)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) withTimeout() internal.StartOption {
|
func (ng *engine) withNetworkTimeout() internal.StartOption {
|
||||||
return func(svr *http.Server) {
|
return func(svr *http.Server) {
|
||||||
timeout := ng.timeout
|
if !ng.hasTimeout() {
|
||||||
if timeout > 0 {
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
||||||
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
||||||
// which triggers the circuit breaker.
|
// which triggers the circuit breaker.
|
||||||
svr.ReadTimeout = 4 * timeout / 5
|
svr.ReadTimeout = 4 * ng.timeout / 5
|
||||||
// factor 1.1, to avoid servers don't have enough time to write responses.
|
// factor 1.1, to avoid servers don't have enough time to write responses.
|
||||||
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
||||||
svr.WriteTimeout = 11 * timeout / 10
|
svr.WriteTimeout = 11 * ng.timeout / 10
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -73,7 +73,17 @@ Verbose: true
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
}},
|
}},
|
||||||
timeout: time.Minute,
|
timeout: ptrOfDuration(time.Minute),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
jwt: jwtSetting{},
|
||||||
|
signature: signatureSetting{},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
timeout: ptrOfDuration(0),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
priority: true,
|
priority: true,
|
||||||
@@ -84,7 +94,7 @@ Verbose: true
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
}},
|
}},
|
||||||
timeout: time.Second,
|
timeout: ptrOfDuration(time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
priority: true,
|
priority: true,
|
||||||
@@ -227,8 +237,12 @@ Verbose: true
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
timeout := time.Second * 3
|
timeout := time.Second * 3
|
||||||
if route.timeout > timeout {
|
if route.timeout != nil {
|
||||||
timeout = route.timeout
|
if *route.timeout == 0 {
|
||||||
|
timeout = 0
|
||||||
|
} else if *route.timeout > timeout {
|
||||||
|
timeout = *route.timeout
|
||||||
|
}
|
||||||
}
|
}
|
||||||
assert.Equal(t, timeout, ng.timeout)
|
assert.Equal(t, timeout, ng.timeout)
|
||||||
})
|
})
|
||||||
@@ -236,10 +250,69 @@ Verbose: true
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewEngine_unsignedCallback(t *testing.T) {
|
||||||
|
priKeyfile, err := fs.TempFilenameWithText(priKey)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
defer os.Remove(priKeyfile)
|
||||||
|
|
||||||
|
yaml := `Name: foo
|
||||||
|
Host: localhost
|
||||||
|
Port: 0
|
||||||
|
Middlewares:
|
||||||
|
Log: false
|
||||||
|
`
|
||||||
|
route := featuredRoutes{
|
||||||
|
priority: true,
|
||||||
|
jwt: jwtSetting{
|
||||||
|
enabled: true,
|
||||||
|
},
|
||||||
|
signature: signatureSetting{
|
||||||
|
enabled: true,
|
||||||
|
SignatureConf: SignatureConf{
|
||||||
|
Strict: true,
|
||||||
|
PrivateKeys: []PrivateKeyConf{
|
||||||
|
{
|
||||||
|
Fingerprint: "a",
|
||||||
|
KeyFile: priKeyfile,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
var index int32
|
||||||
|
t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
|
||||||
|
var cnf RestConf
|
||||||
|
assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
|
||||||
|
ng := newEngine(cnf)
|
||||||
|
if atomic.AddInt32(&index, 1)%2 == 0 {
|
||||||
|
ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request,
|
||||||
|
next http.Handler, strict bool, code int) {
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ng.addRoutes(route)
|
||||||
|
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
|
||||||
|
}))
|
||||||
|
|
||||||
|
assert.Equal(t, time.Duration(time.Second*3), ng.timeout)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestEngine_checkedTimeout(t *testing.T) {
|
func TestEngine_checkedTimeout(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
timeout time.Duration
|
timeout *time.Duration
|
||||||
expect time.Duration
|
expect time.Duration
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -248,17 +321,17 @@ func TestEngine_checkedTimeout(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "less",
|
name: "less",
|
||||||
timeout: time.Millisecond * 500,
|
timeout: ptrOfDuration(time.Millisecond * 500),
|
||||||
expect: time.Millisecond * 500,
|
expect: time.Millisecond * 500,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "equal",
|
name: "equal",
|
||||||
timeout: time.Second,
|
timeout: ptrOfDuration(time.Second),
|
||||||
expect: time.Second,
|
expect: time.Second,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "more",
|
name: "more",
|
||||||
timeout: time.Millisecond * 1500,
|
timeout: ptrOfDuration(time.Millisecond * 1500),
|
||||||
expect: time.Millisecond * 1500,
|
expect: time.Millisecond * 1500,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -394,9 +467,14 @@ func TestEngine_withTimeout(t *testing.T) {
|
|||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
test := test
|
test := test
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
ng := newEngine(RestConf{Timeout: test.timeout})
|
ng := newEngine(RestConf{
|
||||||
|
Timeout: test.timeout,
|
||||||
|
Middlewares: MiddlewaresConf{
|
||||||
|
Timeout: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
svr := &http.Server{}
|
svr := &http.Server{}
|
||||||
ng.withTimeout()(svr)
|
ng.withNetworkTimeout()(svr)
|
||||||
|
|
||||||
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
||||||
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
||||||
@@ -406,6 +484,62 @@ func TestEngine_withTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEngine_ReadWriteTimeout(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
timeout int64
|
||||||
|
middleware bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "0/false",
|
||||||
|
timeout: 0,
|
||||||
|
middleware: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "0/true",
|
||||||
|
timeout: 0,
|
||||||
|
middleware: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set/false",
|
||||||
|
timeout: 1000,
|
||||||
|
middleware: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both set",
|
||||||
|
timeout: 1000,
|
||||||
|
middleware: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
test := test
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
ng := newEngine(RestConf{
|
||||||
|
Timeout: test.timeout,
|
||||||
|
Middlewares: MiddlewaresConf{
|
||||||
|
Timeout: test.middleware,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
svr := &http.Server{}
|
||||||
|
ng.withNetworkTimeout()(svr)
|
||||||
|
|
||||||
|
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
||||||
|
assert.Equal(t, time.Duration(0), svr.IdleTimeout)
|
||||||
|
|
||||||
|
if test.timeout > 0 && test.middleware {
|
||||||
|
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
||||||
|
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, time.Duration(0), svr.ReadTimeout)
|
||||||
|
assert.Equal(t, time.Duration(0), svr.WriteTimeout)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestEngine_start(t *testing.T) {
|
func TestEngine_start(t *testing.T) {
|
||||||
logx.Disable()
|
logx.Disable()
|
||||||
|
|
||||||
|
|||||||
@@ -106,8 +106,8 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
tw.mu.Lock()
|
tw.mu.Lock()
|
||||||
defer tw.mu.Unlock()
|
defer tw.mu.Unlock()
|
||||||
// there isn't any user-defined middleware before TimoutHandler,
|
// there isn't any user-defined middleware before TimeoutHandler,
|
||||||
// so we can guarantee that cancelation in biz related code won't come here.
|
// so we can guarantee that cancellation in biz related code won't come here.
|
||||||
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
|
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
w.WriteHeader(statusClientClosedRequest)
|
w.WriteHeader(statusClientClosedRequest)
|
||||||
@@ -151,7 +151,7 @@ func (tw *timeoutWriter) Flush() {
|
|||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Header returns the underline temporary http.Header.
|
// Header returns the underlying temporary http.Header.
|
||||||
func (tw *timeoutWriter) Header() http.Header {
|
func (tw *timeoutWriter) Header() http.Header {
|
||||||
return tw.h
|
return tw.h
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ func buildRequest(ctx context.Context, method, url string, data any) (*http.Requ
|
|||||||
req.URL.RawQuery = buildFormQuery(u, val[formKey])
|
req.URL.RawQuery = buildFormQuery(u, val[formKey])
|
||||||
fillHeader(req, val[headerKey])
|
fillHeader(req, val[headerKey])
|
||||||
if hasJsonBody {
|
if hasJsonBody {
|
||||||
req.Header.Set(header.ContentType, header.JsonContentType)
|
req.Header.Set(header.ContentType, header.ContentTypeJson)
|
||||||
}
|
}
|
||||||
|
|
||||||
return req, nil
|
return req, nil
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func TestDoRequest_NotFound(t *testing.T) {
|
|||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
|
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
req.Header.Set(header.ContentType, header.JsonContentType)
|
req.Header.Set(header.ContentType, header.ContentTypeJson)
|
||||||
resp, err := DoRequest(req)
|
resp, err := DoRequest(req)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func TestParse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("foo", "bar")
|
w.Header().Set("foo", "bar")
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
w.Write([]byte(`{"name":"kevin","value":100}`))
|
w.Write([]byte(`{"name":"kevin","value":100}`))
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
@@ -38,7 +38,7 @@ func TestParseHeaderError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("foo", "bar")
|
w.Header().Set("foo", "bar")
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||||
@@ -54,7 +54,7 @@ func TestParseNoBody(t *testing.T) {
|
|||||||
}
|
}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("foo", "bar")
|
w.Header().Set("foo", "bar")
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||||
@@ -72,7 +72,7 @@ func TestParseWithZeroValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("foo", "0")
|
w.Header().Set("foo", "0")
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
w.Write([]byte(`{"bar":0}`))
|
w.Write([]byte(`{"bar":0}`))
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
@@ -90,7 +90,7 @@ func TestParseWithNegativeContentLength(t *testing.T) {
|
|||||||
Bar int `json:"bar"`
|
Bar int `json:"bar"`
|
||||||
}
|
}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
w.Write([]byte(`{"bar":0}`))
|
w.Write([]byte(`{"bar":0}`))
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
@@ -124,7 +124,7 @@ func TestParseWithNegativeContentLength(t *testing.T) {
|
|||||||
func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
|
func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
|
||||||
var val struct{}
|
var val struct{}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||||
@@ -156,7 +156,7 @@ func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
|
|||||||
func TestParseJsonBody_BodyError(t *testing.T) {
|
func TestParseJsonBody_BodyError(t *testing.T) {
|
||||||
var val struct{}
|
var val struct{}
|
||||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||||
}))
|
}))
|
||||||
defer svr.Close()
|
defer svr.Close()
|
||||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func TestNamedService_DoRequestPost(t *testing.T) {
|
|||||||
service := NewService("foo")
|
service := NewService("foo")
|
||||||
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
|
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
req.Header.Set(header.ContentType, header.JsonContentType)
|
req.Header.Set(header.ContentType, header.ContentTypeJson)
|
||||||
resp, err := service.DoRequest(req)
|
resp, err := service.DoRequest(req)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ func TestParseFormArray(t *testing.T) {
|
|||||||
http.NoBody)
|
http.NoBody)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
if assert.NoError(t, Parse(r, &v)) {
|
if assert.NoError(t, Parse(r, &v)) {
|
||||||
assert.ElementsMatch(t, []string{"1", "2", "3"}, v.Names)
|
assert.ElementsMatch(t, []string{"1,2,3"}, v.Names)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -189,9 +189,7 @@ func TestParseFormArray(t *testing.T) {
|
|||||||
"/a?numbers=1,2,3",
|
"/a?numbers=1,2,3",
|
||||||
http.NoBody)
|
http.NoBody)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
if assert.NoError(t, Parse(r, &v)) {
|
assert.Error(t, Parse(r, &v))
|
||||||
assert.ElementsMatch(t, []int{1, 2, 3}, v.Numbers)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("slice with one value on array format brackets", func(t *testing.T) {
|
t.Run("slice with one value on array format brackets", func(t *testing.T) {
|
||||||
@@ -268,6 +266,36 @@ func TestParseFormArray(t *testing.T) {
|
|||||||
assert.ElementsMatch(t, []float64{2}, v.Numbers)
|
assert.ElementsMatch(t, []float64{2}, v.Numbers)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("slice with one value", func(t *testing.T) {
|
||||||
|
var v struct {
|
||||||
|
Codes []string `form:"codes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest(
|
||||||
|
http.MethodGet,
|
||||||
|
"/a?codes=aaa,bbb,ccc",
|
||||||
|
http.NoBody)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
if assert.NoError(t, Parse(r, &v)) {
|
||||||
|
assert.ElementsMatch(t, []string{"aaa,bbb,ccc"}, v.Codes)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("slice with multiple values", func(t *testing.T) {
|
||||||
|
var v struct {
|
||||||
|
Codes []string `form:"codes,arrayComma=false"`
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest(
|
||||||
|
http.MethodGet,
|
||||||
|
"/a?codes=aaa,bbb,ccc&codes=ccc,ddd,eee",
|
||||||
|
http.NoBody)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
if assert.NoError(t, Parse(r, &v)) {
|
||||||
|
assert.ElementsMatch(t, []string{"aaa,bbb,ccc", "ccc,ddd,eee"}, v.Codes)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseForm_Error(t *testing.T) {
|
func TestParseForm_Error(t *testing.T) {
|
||||||
@@ -448,7 +476,7 @@ func TestParseJsonBody(t *testing.T) {
|
|||||||
|
|
||||||
body := `{"name":"kevin", "age": 18}`
|
body := `{"name":"kevin", "age": 18}`
|
||||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
if assert.NoError(t, Parse(r, &v)) {
|
if assert.NoError(t, Parse(r, &v)) {
|
||||||
assert.Equal(t, "kevin", v.Name)
|
assert.Equal(t, "kevin", v.Name)
|
||||||
@@ -464,7 +492,7 @@ func TestParseJsonBody(t *testing.T) {
|
|||||||
|
|
||||||
body := `{"name":"kevin", "ag": 18}`
|
body := `{"name":"kevin", "ag": 18}`
|
||||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
assert.Error(t, Parse(r, &v))
|
assert.Error(t, Parse(r, &v))
|
||||||
})
|
})
|
||||||
@@ -489,7 +517,7 @@ func TestParseJsonBody(t *testing.T) {
|
|||||||
|
|
||||||
body := `[{"name":"kevin", "age": 18}]`
|
body := `[{"name":"kevin", "age": 18}]`
|
||||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
assert.NoError(t, Parse(r, &v))
|
assert.NoError(t, Parse(r, &v))
|
||||||
assert.Equal(t, 1, len(v))
|
assert.Equal(t, 1, len(v))
|
||||||
@@ -509,7 +537,7 @@ func TestParseJsonBody(t *testing.T) {
|
|||||||
|
|
||||||
body := `[{"name":"apple", "age": 18}]`
|
body := `[{"name":"apple", "age": 18}]`
|
||||||
r := httptest.NewRequest(http.MethodPost, "/a?product=tree", strings.NewReader(body))
|
r := httptest.NewRequest(http.MethodPost, "/a?product=tree", strings.NewReader(body))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
assert.NoError(t, Parse(r, &v))
|
assert.NoError(t, Parse(r, &v))
|
||||||
assert.Equal(t, 1, len(v))
|
assert.Equal(t, 1, len(v))
|
||||||
@@ -527,7 +555,7 @@ func TestParseJsonBody(t *testing.T) {
|
|||||||
body, _ := json.Marshal(v1)
|
body, _ := json.Marshal(v1)
|
||||||
t.Logf("body:%s", string(body))
|
t.Logf("body:%s", string(body))
|
||||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(body)))
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(body)))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||||
var v2 v
|
var v2 v
|
||||||
err := ParseJsonBody(r, &v2)
|
err := ParseJsonBody(r, &v2)
|
||||||
if assert.NoError(t, err) {
|
if assert.NoError(t, err) {
|
||||||
@@ -581,7 +609,7 @@ func TestParseHeaders(t *testing.T) {
|
|||||||
request.Header.Add("addrs", "addr2")
|
request.Header.Add("addrs", "addr2")
|
||||||
request.Header.Add("X-Forwarded-For", "10.0.10.11")
|
request.Header.Add("X-Forwarded-For", "10.0.10.11")
|
||||||
request.Header.Add("x-real-ip", "10.0.11.10")
|
request.Header.Add("x-real-ip", "10.0.11.10")
|
||||||
request.Header.Add("Accept", header.JsonContentType)
|
request.Header.Add("Accept", header.ContentTypeJson)
|
||||||
err = ParseHeaders(request, &v)
|
err = ParseHeaders(request, &v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -591,7 +619,7 @@ func TestParseHeaders(t *testing.T) {
|
|||||||
assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs)
|
assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs)
|
||||||
assert.Equal(t, "10.0.10.11", v.XForwardedFor)
|
assert.Equal(t, "10.0.10.11", v.XForwardedFor)
|
||||||
assert.Equal(t, "10.0.11.10", v.XRealIP)
|
assert.Equal(t, "10.0.11.10", v.XRealIP)
|
||||||
assert.Equal(t, header.JsonContentType, v.Accept)
|
assert.Equal(t, header.ContentTypeJson, v.Accept)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseHeaders_Error(t *testing.T) {
|
func TestParseHeaders_Error(t *testing.T) {
|
||||||
@@ -683,7 +711,7 @@ func TestParseWithFloatPtr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
body := `{"weightFloat32": 3.2}`
|
body := `{"weightFloat32": 3.2}`
|
||||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
if assert.NoError(t, Parse(r, &v)) {
|
if assert.NoError(t, Parse(r, &v)) {
|
||||||
assert.Equal(t, float32(3.2), *v.WeightFloat32)
|
assert.Equal(t, float32(3.2), *v.WeightFloat32)
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ func doWriteJson(w http.ResponseWriter, code int, v any) error {
|
|||||||
return fmt.Errorf("marshal json failed, error: %w", err)
|
return fmt.Errorf("marshal json failed, error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set(ContentType, header.JsonContentType)
|
w.Header().Set(ContentType, header.ContentTypeJson)
|
||||||
w.WriteHeader(code)
|
w.WriteHeader(code)
|
||||||
|
|
||||||
if n, err := w.Write(bs); err != nil {
|
if n, err := w.Write(bs); err != nil {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ const (
|
|||||||
// ContentType means Content-Type.
|
// ContentType means Content-Type.
|
||||||
ContentType = header.ContentType
|
ContentType = header.ContentType
|
||||||
// JsonContentType means application/json.
|
// JsonContentType means application/json.
|
||||||
JsonContentType = header.JsonContentType
|
JsonContentType = header.ContentTypeJson
|
||||||
// KeyField means key.
|
// KeyField means key.
|
||||||
KeyField = "key"
|
KeyField = "key"
|
||||||
// SecretField means secret.
|
// SecretField means secret.
|
||||||
|
|||||||
@@ -2,15 +2,16 @@ package fileserver
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Middleware returns a middleware that serves files from the given file system.
|
// Middleware returns a middleware that serves files from the given file system.
|
||||||
func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
|
func Middleware(upath string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
|
||||||
fileServer := http.FileServer(fs)
|
fileServer := http.FileServer(fs)
|
||||||
pathWithoutTrailSlash := ensureNoTrailingSlash(path)
|
pathWithoutTrailSlash := ensureNoTrailingSlash(upath)
|
||||||
canServe := createServeChecker(path, fs)
|
canServe := createServeChecker(upath, fs)
|
||||||
|
|
||||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -28,9 +29,22 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
|||||||
var lock sync.RWMutex
|
var lock sync.RWMutex
|
||||||
fileChecker := make(map[string]bool)
|
fileChecker := make(map[string]bool)
|
||||||
|
|
||||||
return func(path string) bool {
|
return func(upath string) bool {
|
||||||
|
// Emulate http.Dir.Open’s path normalization for embed.FS.Open.
|
||||||
|
// http.FileServer redirects any request ending in "/index.html"
|
||||||
|
// to the same path without the final "index.html".
|
||||||
|
// So the path here may be empty or end with a "/".
|
||||||
|
// http.Dir.Open uses this logic to clean the path,
|
||||||
|
// correctly handling those two cases.
|
||||||
|
// embed.FS doesn’t perform this normalization, so we apply the same logic here.
|
||||||
|
upath = path.Clean("/" + upath)[1:]
|
||||||
|
if len(upath) == 0 {
|
||||||
|
// if the path is empty, we use "." to open the current directory
|
||||||
|
upath = "."
|
||||||
|
}
|
||||||
|
|
||||||
lock.RLock()
|
lock.RLock()
|
||||||
exist, ok := fileChecker[path]
|
exist, ok := fileChecker[upath]
|
||||||
lock.RUnlock()
|
lock.RUnlock()
|
||||||
if ok {
|
if ok {
|
||||||
return exist
|
return exist
|
||||||
@@ -39,9 +53,9 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
|||||||
lock.Lock()
|
lock.Lock()
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
|
|
||||||
file, err := fs.Open(path)
|
file, err := fs.Open(upath)
|
||||||
exist = err == nil
|
exist = err == nil
|
||||||
fileChecker[path] = exist
|
fileChecker[upath] = exist
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -51,8 +65,8 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) bool {
|
func createServeChecker(upath string, fs http.FileSystem) func(r *http.Request) bool {
|
||||||
pathWithTrailSlash := ensureTrailingSlash(path)
|
pathWithTrailSlash := ensureTrailingSlash(upath)
|
||||||
fileChecker := createFileChecker(fs)
|
fileChecker := createFileChecker(fs)
|
||||||
|
|
||||||
return func(r *http.Request) bool {
|
return func(r *http.Request) bool {
|
||||||
@@ -62,18 +76,18 @@ func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) b
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureTrailingSlash(path string) string {
|
func ensureTrailingSlash(upath string) string {
|
||||||
if strings.HasSuffix(path, "/") {
|
if strings.HasSuffix(upath, "/") {
|
||||||
return path
|
return upath
|
||||||
}
|
}
|
||||||
|
|
||||||
return path + "/"
|
return upath + "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureNoTrailingSlash(path string) string {
|
func ensureNoTrailingSlash(upath string) string {
|
||||||
if strings.HasSuffix(path, "/") {
|
if strings.HasSuffix(upath, "/") {
|
||||||
return path[:len(path)-1]
|
return upath[:len(upath)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
return path
|
return upath
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package fileserver
|
package fileserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"embed"
|
||||||
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -61,6 +63,46 @@ func TestMiddleware(t *testing.T) {
|
|||||||
requestPath: "/ws",
|
requestPath: "/ws",
|
||||||
expectedStatus: http.StatusAlreadyReported,
|
expectedStatus: http.StatusAlreadyReported,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// http.FileServer redirects any request ending in "/index.html"
|
||||||
|
// to the same path, without the final "index.html".
|
||||||
|
{
|
||||||
|
name: "Serve index.html",
|
||||||
|
path: "/static",
|
||||||
|
dir: "testdata",
|
||||||
|
requestPath: "/static/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Serve index.html with path with trailing slash",
|
||||||
|
path: "/static/",
|
||||||
|
dir: "testdata",
|
||||||
|
requestPath: "/static/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Serve index.html in a nested directory",
|
||||||
|
path: "/static",
|
||||||
|
dir: "testdata",
|
||||||
|
requestPath: "/static/nested/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Request index.html indirectly",
|
||||||
|
path: "/static",
|
||||||
|
dir: "testdata",
|
||||||
|
requestPath: "/static/",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Request index.html in a nested directory indirectly",
|
||||||
|
path: "/static",
|
||||||
|
dir: "testdata",
|
||||||
|
requestPath: "/static/nested/",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "hello",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -87,6 +129,128 @@ func TestMiddleware(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed testdata
|
||||||
|
testdataFS embed.FS
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMiddleware_embedFS(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
requestPath string
|
||||||
|
expectedStatus int
|
||||||
|
expectedContent string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Serve static file",
|
||||||
|
path: "/static",
|
||||||
|
requestPath: "/static/example.txt",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Path with trailing slash",
|
||||||
|
path: "/static/",
|
||||||
|
requestPath: "/static/example.txt",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Root path",
|
||||||
|
path: "/",
|
||||||
|
requestPath: "/example.txt",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Pass through non-matching path",
|
||||||
|
path: "/static/",
|
||||||
|
requestPath: "/other/path",
|
||||||
|
expectedStatus: http.StatusAlreadyReported,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not exist file",
|
||||||
|
path: "/assets",
|
||||||
|
requestPath: "/assets/not-exist.txt",
|
||||||
|
expectedStatus: http.StatusAlreadyReported,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not exist file in root",
|
||||||
|
path: "/",
|
||||||
|
requestPath: "/not-exist.txt",
|
||||||
|
expectedStatus: http.StatusAlreadyReported,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "websocket request",
|
||||||
|
path: "/",
|
||||||
|
requestPath: "/ws",
|
||||||
|
expectedStatus: http.StatusAlreadyReported,
|
||||||
|
},
|
||||||
|
|
||||||
|
// http.FileServer redirects any request ending in "/index.html"
|
||||||
|
// to the same path, without the final "index.html".
|
||||||
|
{
|
||||||
|
name: "Serve index.html",
|
||||||
|
path: "/static",
|
||||||
|
requestPath: "/static/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Serve index.html with path with trailing slash",
|
||||||
|
path: "/static/",
|
||||||
|
requestPath: "/static/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Serve index.html in a nested directory",
|
||||||
|
path: "/static",
|
||||||
|
requestPath: "/static/nested/index.html",
|
||||||
|
expectedStatus: http.StatusMovedPermanently,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Request index.html indirectly",
|
||||||
|
path: "/static",
|
||||||
|
requestPath: "/static/",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Request index.html in a nested directory indirectly",
|
||||||
|
path: "/static",
|
||||||
|
requestPath: "/static/nested/",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedContent: "hello",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
subFS, err := fs.Sub(testdataFS, "testdata")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
middleware := Middleware(tt.path, http.FS(subFS))
|
||||||
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusAlreadyReported)
|
||||||
|
})
|
||||||
|
|
||||||
|
handlerToTest := middleware(nextHandler)
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handlerToTest.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedStatus, rr.Code)
|
||||||
|
if len(tt.expectedContent) > 0 {
|
||||||
|
assert.Equal(t, tt.expectedContent, rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestEnsureTrailingSlash(t *testing.T) {
|
func TestEnsureTrailingSlash(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
|
|||||||
1
rest/internal/fileserver/testdata/index.html
vendored
Normal file
1
rest/internal/fileserver/testdata/index.html
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
hello
|
||||||
1
rest/internal/fileserver/testdata/nested/index.html
vendored
Normal file
1
rest/internal/fileserver/testdata/nested/index.html
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
hello
|
||||||
@@ -3,8 +3,18 @@ package header
|
|||||||
const (
|
const (
|
||||||
// ApplicationJson stands for application/json.
|
// ApplicationJson stands for application/json.
|
||||||
ApplicationJson = "application/json"
|
ApplicationJson = "application/json"
|
||||||
|
// CacheControl is the header key for Cache-Control.
|
||||||
|
CacheControl = "Cache-Control"
|
||||||
|
// CacheControlNoCache is the value for Cache-Control: no-cache.
|
||||||
|
CacheControlNoCache = "no-cache"
|
||||||
|
// Connection is the header key for Connection.
|
||||||
|
Connection = "Connection"
|
||||||
|
// ConnectionKeepAlive is the value for Connection: keep-alive.
|
||||||
|
ConnectionKeepAlive = "keep-alive"
|
||||||
// ContentType is the header key for Content-Type.
|
// ContentType is the header key for Content-Type.
|
||||||
ContentType = "Content-Type"
|
ContentType = "Content-Type"
|
||||||
// JsonContentType is the content type for JSON.
|
// ContentTypeJson is the content type for JSON.
|
||||||
JsonContentType = "application/json; charset=utf-8"
|
ContentTypeJson = "application/json; charset=utf-8"
|
||||||
|
// ContentTypeEventStream is the content type for event stream.
|
||||||
|
ContentTypeEventStream = "text/event-stream"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -628,7 +628,7 @@ func TestParseWrappedRequest(t *testing.T) {
|
|||||||
func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
|
func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
|
||||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", bytes.NewReader(nil))
|
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", bytes.NewReader(nil))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Request struct {
|
Request struct {
|
||||||
@@ -661,7 +661,7 @@ func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
|
|||||||
func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) {
|
func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) {
|
||||||
r, err := http.NewRequest(http.MethodHead, "http://hello.com/kevin/2017", bytes.NewReader(nil))
|
r, err := http.NewRequest(http.MethodHead, "http://hello.com/kevin/2017", bytes.NewReader(nil))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
Request struct {
|
Request struct {
|
||||||
@@ -758,7 +758,7 @@ func TestParseWithAllUtf8(t *testing.T) {
|
|||||||
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
|
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
|
||||||
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
|
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
router := NewRouter()
|
router := NewRouter()
|
||||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||||
@@ -948,7 +948,7 @@ func TestParseWithMissingAllPaths(t *testing.T) {
|
|||||||
func TestParseGetWithContentLengthHeader(t *testing.T) {
|
func TestParseGetWithContentLengthHeader(t *testing.T) {
|
||||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
|
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||||
r.Header.Set(contentLength, "1024")
|
r.Header.Set(contentLength, "1024")
|
||||||
|
|
||||||
router := NewRouter()
|
router := NewRouter()
|
||||||
@@ -976,7 +976,7 @@ func TestParseJsonPostWithTypeMismatch(t *testing.T) {
|
|||||||
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
|
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
|
||||||
bytes.NewBufferString(`{"time": "20170912"}`))
|
bytes.NewBufferString(`{"time": "20170912"}`))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
router := NewRouter()
|
router := NewRouter()
|
||||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||||
@@ -1002,7 +1002,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) {
|
|||||||
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017",
|
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017",
|
||||||
bytes.NewBufferString(`{"time": 20170912}`))
|
bytes.NewBufferString(`{"time": 20170912}`))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||||
|
|
||||||
router := NewRouter()
|
router := NewRouter()
|
||||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||||
|
|||||||
@@ -63,6 +63,11 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
|
|||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddRoute adds given route into the Server.
|
||||||
|
func (s *Server) AddRoute(r Route, opts ...RouteOption) {
|
||||||
|
s.AddRoutes([]Route{r}, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
// AddRoutes add given routes into the Server.
|
// AddRoutes add given routes into the Server.
|
||||||
func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
|
func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
|
||||||
r := featuredRoutes{
|
r := featuredRoutes{
|
||||||
@@ -74,11 +79,6 @@ func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
|
|||||||
s.ngin.addRoutes(r)
|
s.ngin.addRoutes(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddRoute adds given route into the Server.
|
|
||||||
func (s *Server) AddRoute(r Route, opts ...RouteOption) {
|
|
||||||
s.AddRoutes([]Route{r}, opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PrintRoutes prints the added routes to stdout.
|
// PrintRoutes prints the added routes to stdout.
|
||||||
func (s *Server) PrintRoutes() {
|
func (s *Server) PrintRoutes() {
|
||||||
s.ngin.print()
|
s.ngin.print()
|
||||||
@@ -95,25 +95,6 @@ func (s *Server) Routes() []Route {
|
|||||||
return routes
|
return routes
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeHTTP is for test purpose, allow developer to do a unit test with
|
|
||||||
// all defined router without starting an HTTP Server.
|
|
||||||
//
|
|
||||||
// For example:
|
|
||||||
//
|
|
||||||
// server := MustNewServer(...)
|
|
||||||
// server.addRoute(...) // router a
|
|
||||||
// server.addRoute(...) // router b
|
|
||||||
// server.addRoute(...) // router c
|
|
||||||
//
|
|
||||||
// r, _ := http.NewRequest(...)
|
|
||||||
// w := httptest.NewRecorder(...)
|
|
||||||
// server.ServeHTTP(w, r)
|
|
||||||
// // verify the response
|
|
||||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
s.ngin.bindRoutes(s.router)
|
|
||||||
s.router.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start starts the Server.
|
// Start starts the Server.
|
||||||
// Graceful shutdown is enabled by default.
|
// Graceful shutdown is enabled by default.
|
||||||
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
|
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
|
||||||
@@ -138,6 +119,16 @@ func (s *Server) Use(middleware Middleware) {
|
|||||||
s.ngin.use(middleware)
|
s.ngin.use(middleware)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// build builds the Server and binds the routes to the router.
|
||||||
|
func (s *Server) build() error {
|
||||||
|
return s.ngin.bindRoutes(s.router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// serve serves the HTTP requests using the Server's router.
|
||||||
|
func (s *Server) serve(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.router.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
// ToMiddleware converts the given handler to a Middleware.
|
// ToMiddleware converts the given handler to a Middleware.
|
||||||
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
||||||
return func(handle http.HandlerFunc) http.HandlerFunc {
|
return func(handle http.HandlerFunc) http.HandlerFunc {
|
||||||
@@ -298,10 +289,18 @@ func WithSignature(signature SignatureConf) RouteOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithSSE returns a RouteOption to enable server-sent events.
|
||||||
|
func WithSSE() RouteOption {
|
||||||
|
return func(r *featuredRoutes) {
|
||||||
|
r.sse = true
|
||||||
|
r.timeout = ptrOfDuration(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithTimeout returns a RouteOption to set timeout with given value.
|
// WithTimeout returns a RouteOption to set timeout with given value.
|
||||||
func WithTimeout(timeout time.Duration) RouteOption {
|
func WithTimeout(timeout time.Duration) RouteOption {
|
||||||
return func(r *featuredRoutes) {
|
return func(r *featuredRoutes) {
|
||||||
r.timeout = timeout
|
r.timeout = &timeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,6 +335,10 @@ func handleError(err error) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ptrOfDuration(d time.Duration) *time.Duration {
|
||||||
|
return &d
|
||||||
|
}
|
||||||
|
|
||||||
func validateSecret(secret string) {
|
func validateSecret(secret string) {
|
||||||
if len(secret) < 8 {
|
if len(secret) < 8 {
|
||||||
panic("secret's length can't be less than 8")
|
panic("secret's length can't be less than 8")
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/rest/chain"
|
"github.com/zeromicro/go-zero/rest/chain"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"github.com/zeromicro/go-zero/rest/internal/cors"
|
"github.com/zeromicro/go-zero/rest/internal/cors"
|
||||||
|
"github.com/zeromicro/go-zero/rest/internal/header"
|
||||||
"github.com/zeromicro/go-zero/rest/router"
|
"github.com/zeromicro/go-zero/rest/router"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -231,7 +232,7 @@ func TestWithFileServerMiddleware(t *testing.T) {
|
|||||||
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
|
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
server.ServeHTTP(rr, req)
|
serve(server, rr, req)
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedStatus, rr.Code)
|
assert.Equal(t, tt.expectedStatus, rr.Code)
|
||||||
if len(tt.expectedContent) > 0 {
|
if len(tt.expectedContent) > 0 {
|
||||||
@@ -344,7 +345,7 @@ func TestWithPriority(t *testing.T) {
|
|||||||
func TestWithTimeout(t *testing.T) {
|
func TestWithTimeout(t *testing.T) {
|
||||||
var fr featuredRoutes
|
var fr featuredRoutes
|
||||||
WithTimeout(time.Hour)(&fr)
|
WithTimeout(time.Hour)(&fr)
|
||||||
assert.Equal(t, time.Hour, fr.timeout)
|
assert.Equal(t, time.Hour, *fr.timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWithTLSConfig(t *testing.T) {
|
func TestWithTLSConfig(t *testing.T) {
|
||||||
@@ -458,7 +459,7 @@ Port: 54321
|
|||||||
// we would need to verify the behavior here. Since we don't have
|
// we would need to verify the behavior here. Since we don't have
|
||||||
// direct access to headers, we'll mock newCorsRouter to capture it.
|
// direct access to headers, we'll mock newCorsRouter to capture it.
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
svr.ServeHTTP(w, httptest.NewRequest(http.MethodOptions, "/", nil))
|
serve(svr, w, httptest.NewRequest(http.MethodOptions, "/", nil))
|
||||||
|
|
||||||
vals := w.Header().Values("Access-Control-Allow-Headers")
|
vals := w.Header().Values("Access-Control-Allow-Headers")
|
||||||
respHeaders := make(map[string]struct{})
|
respHeaders := make(map[string]struct{})
|
||||||
@@ -748,12 +749,46 @@ Port: 54321
|
|||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
req, _ := http.NewRequest("GET", test.path, nil)
|
req, _ := http.NewRequest("GET", test.path, nil)
|
||||||
svr.ServeHTTP(w, req)
|
serve(svr, w, req)
|
||||||
assert.Equal(t, test.code, w.Code)
|
assert.Equal(t, test.code, w.Code)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerEventStream(t *testing.T) {
|
||||||
|
server := MustNewServer(RestConf{})
|
||||||
|
server.AddRoutes([]Route{
|
||||||
|
{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/foo",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("foo"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/bar",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("bar"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, WithSSE())
|
||||||
|
|
||||||
|
check := func(val string) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%s", val), http.NoBody)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
serve(server, rr, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rr.Code)
|
||||||
|
assert.Equal(t, header.ContentTypeEventStream, rr.Header().Get(header.ContentType))
|
||||||
|
assert.Equal(t, header.CacheControlNoCache, rr.Header().Get(header.CacheControl))
|
||||||
|
assert.Equal(t, header.ConnectionKeepAlive, rr.Header().Get(header.Connection))
|
||||||
|
assert.Equal(t, val, rr.Body.String())
|
||||||
|
}
|
||||||
|
check("foo")
|
||||||
|
check("bar")
|
||||||
|
}
|
||||||
|
|
||||||
//go:embed testdata
|
//go:embed testdata
|
||||||
var content embed.FS
|
var content embed.FS
|
||||||
|
|
||||||
@@ -765,6 +800,25 @@ func TestServerEmbedFileSystem(t *testing.T) {
|
|||||||
req, err := http.NewRequest(http.MethodGet, "/assets/sample.txt", http.NoBody)
|
req, err := http.NewRequest(http.MethodGet, "/assets/sample.txt", http.NoBody)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
server.ServeHTTP(rr, req)
|
serve(server, rr, req)
|
||||||
assert.Equal(t, sampleContent, rr.Body.String())
|
assert.Equal(t, sampleContent, rr.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// serve is for test purpose, allow developer to do a unit test with
|
||||||
|
// all defined routes without starting an HTTP Server.
|
||||||
|
//
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// server := MustNewServer(...)
|
||||||
|
// server.addRoute(...) // router a
|
||||||
|
// server.addRoute(...) // router b
|
||||||
|
// server.addRoute(...) // router c
|
||||||
|
//
|
||||||
|
// r, _ := http.NewRequest(...)
|
||||||
|
// w := httptest.NewRecorder(...)
|
||||||
|
// serve(server, w, r)
|
||||||
|
// // verify the response
|
||||||
|
func serve(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = s.build()
|
||||||
|
s.serve(w, r)
|
||||||
|
}
|
||||||
|
|||||||
27
rest/serverless.go
Normal file
27
rest/serverless.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package rest
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// Serverless is a wrapper around Server that allows it to be used in serverless environments.
|
||||||
|
type Serverless struct {
|
||||||
|
server *Server
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServerless creates a new Serverless instance from the provided Server.
|
||||||
|
func NewServerless(server *Server) (*Serverless, error) {
|
||||||
|
// Ensure the server is built before using it in a serverless context.
|
||||||
|
// Why not call server.build() when serving requests,
|
||||||
|
// is because we need to ensure fail fast behavior.
|
||||||
|
if err := server.build(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Serverless{
|
||||||
|
server: server,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve handles HTTP requests by delegating them to the underlying Server instance.
|
||||||
|
func (s *Serverless) Serve(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.server.serve(w, r)
|
||||||
|
}
|
||||||
67
rest/serverless_test.go
Normal file
67
rest/serverless_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package rest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewServerless(t *testing.T) {
|
||||||
|
logtest.Discard(t)
|
||||||
|
|
||||||
|
const configYaml = `
|
||||||
|
Name: foo
|
||||||
|
Host: localhost
|
||||||
|
Port: 0
|
||||||
|
`
|
||||||
|
var cnf RestConf
|
||||||
|
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
|
||||||
|
|
||||||
|
svr, err := NewServer(cnf)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
svr.AddRoute(Route{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("Hello World"))
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
serverless, err := NewServerless(svr)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
serverless.Serve(w, r)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "Hello World", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewServerlessWithError(t *testing.T) {
|
||||||
|
logtest.Discard(t)
|
||||||
|
|
||||||
|
const configYaml = `
|
||||||
|
Name: foo
|
||||||
|
Host: localhost
|
||||||
|
Port: 0
|
||||||
|
`
|
||||||
|
var cnf RestConf
|
||||||
|
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
|
||||||
|
|
||||||
|
svr, err := NewServer(cnf)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
svr.AddRoute(Route{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "notstartwith/",
|
||||||
|
Handler: nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err = NewServerless(svr)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
@@ -31,10 +31,11 @@ type (
|
|||||||
}
|
}
|
||||||
|
|
||||||
featuredRoutes struct {
|
featuredRoutes struct {
|
||||||
timeout time.Duration
|
timeout *time.Duration
|
||||||
priority bool
|
priority bool
|
||||||
jwt jwtSetting
|
jwt jwtSetting
|
||||||
signature signatureSetting
|
signature signatureSetting
|
||||||
|
sse bool
|
||||||
routes []Route
|
routes []Route
|
||||||
maxBytes int64
|
maxBytes int64
|
||||||
}
|
}
|
||||||
|
|||||||
1
tools/goctl/.gitignore
vendored
Normal file
1
tools/goctl/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
dist
|
||||||
@@ -3,7 +3,8 @@ FROM golang:alpine AS builder
|
|||||||
LABEL stage=gobuilder
|
LABEL stage=gobuilder
|
||||||
|
|
||||||
ENV CGO_ENABLED=0
|
ENV CGO_ENABLED=0
|
||||||
ENV GOPROXY=https://goproxy.cn,direct
|
# if you are in China, you can use the following command to speed up the download
|
||||||
|
# ENV GOPROXY=https://goproxy.cn,direct
|
||||||
|
|
||||||
RUN apk update --no-cache && apk add --no-cache tzdata
|
RUN apk update --no-cache && apk add --no-cache tzdata
|
||||||
RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@latest
|
RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@latest
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/apigen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/apigen"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/dartgen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/dartgen"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/docgen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/docgen"
|
||||||
@@ -10,6 +11,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/tools/goctl/api/javagen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/javagen"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/ktgen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/ktgen"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/new"
|
"github.com/zeromicro/go-zero/tools/goctl/api/new"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/api/swagger"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/tsgen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/tsgen"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
|
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
@@ -31,6 +33,7 @@ var (
|
|||||||
ktCmd = cobrax.NewCommand("kt", cobrax.WithRunE(ktgen.KtCommand))
|
ktCmd = cobrax.NewCommand("kt", cobrax.WithRunE(ktgen.KtCommand))
|
||||||
pluginCmd = cobrax.NewCommand("plugin", cobrax.WithRunE(plugin.PluginCommand))
|
pluginCmd = cobrax.NewCommand("plugin", cobrax.WithRunE(plugin.PluginCommand))
|
||||||
tsCmd = cobrax.NewCommand("ts", cobrax.WithRunE(tsgen.TsCommand))
|
tsCmd = cobrax.NewCommand("ts", cobrax.WithRunE(tsgen.TsCommand))
|
||||||
|
swaggerCmd = cobrax.NewCommand("swagger", cobrax.WithRunE(swagger.Command))
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -46,6 +49,7 @@ func init() {
|
|||||||
pluginCmdFlags = pluginCmd.Flags()
|
pluginCmdFlags = pluginCmd.Flags()
|
||||||
tsCmdFlags = tsCmd.Flags()
|
tsCmdFlags = tsCmd.Flags()
|
||||||
validateCmdFlags = validateCmd.Flags()
|
validateCmdFlags = validateCmd.Flags()
|
||||||
|
swaggerCmdFlags = swaggerCmd.Flags()
|
||||||
)
|
)
|
||||||
|
|
||||||
apiCmdFlags.StringVar(&apigen.VarStringOutput, "o")
|
apiCmdFlags.StringVar(&apigen.VarStringOutput, "o")
|
||||||
@@ -73,6 +77,7 @@ func init() {
|
|||||||
goCmdFlags.StringVar(&gogen.VarStringRemote, "remote")
|
goCmdFlags.StringVar(&gogen.VarStringRemote, "remote")
|
||||||
goCmdFlags.StringVar(&gogen.VarStringBranch, "branch")
|
goCmdFlags.StringVar(&gogen.VarStringBranch, "branch")
|
||||||
goCmdFlags.BoolVar(&gogen.VarBoolWithTest, "test")
|
goCmdFlags.BoolVar(&gogen.VarBoolWithTest, "test")
|
||||||
|
goCmdFlags.BoolVar(&gogen.VarBoolTypeGroup, "type-group")
|
||||||
goCmdFlags.StringVarWithDefaultValue(&gogen.VarStringStyle, "style", config.DefaultFormat)
|
goCmdFlags.StringVarWithDefaultValue(&gogen.VarStringStyle, "style", config.DefaultFormat)
|
||||||
|
|
||||||
javaCmdFlags.StringVar(&javagen.VarStringDir, "dir")
|
javaCmdFlags.StringVar(&javagen.VarStringDir, "dir")
|
||||||
@@ -97,8 +102,13 @@ func init() {
|
|||||||
tsCmdFlags.StringVar(&tsgen.VarStringCaller, "caller")
|
tsCmdFlags.StringVar(&tsgen.VarStringCaller, "caller")
|
||||||
tsCmdFlags.BoolVar(&tsgen.VarBoolUnWrap, "unwrap")
|
tsCmdFlags.BoolVar(&tsgen.VarBoolUnWrap, "unwrap")
|
||||||
|
|
||||||
|
swaggerCmdFlags.StringVar(&swagger.VarStringAPI, "api")
|
||||||
|
swaggerCmdFlags.StringVar(&swagger.VarStringDir, "dir")
|
||||||
|
swaggerCmdFlags.StringVar(&swagger.VarStringFilename, "filename")
|
||||||
|
swaggerCmdFlags.BoolVar(&swagger.VarBoolYaml, "yaml")
|
||||||
|
|
||||||
validateCmdFlags.StringVar(&validate.VarStringAPI, "api")
|
validateCmdFlags.StringVar(&validate.VarStringAPI, "api")
|
||||||
|
|
||||||
// Add sub-commands
|
// Add sub-commands
|
||||||
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd)
|
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd, swaggerCmd)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,10 +10,10 @@ const (
|
|||||||
import 'package:shared_preferences/shared_preferences.dart';
|
import 'package:shared_preferences/shared_preferences.dart';
|
||||||
import '../data/tokens.dart';
|
import '../data/tokens.dart';
|
||||||
|
|
||||||
/// 保存tokens到本地
|
/// store tokens to local
|
||||||
///
|
///
|
||||||
/// 传入null则删除本地tokens
|
/// pass null will clean local stored tokens
|
||||||
/// 返回:true:设置成功 false:设置失败
|
/// returns true if success, otherwise false
|
||||||
Future<bool> setTokens(Tokens tokens) async {
|
Future<bool> setTokens(Tokens tokens) async {
|
||||||
var sp = await SharedPreferences.getInstance();
|
var sp = await SharedPreferences.getInstance();
|
||||||
if (tokens == null) {
|
if (tokens == null) {
|
||||||
@@ -23,9 +23,9 @@ Future<bool> setTokens(Tokens tokens) async {
|
|||||||
return await sp.setString('tokens', jsonEncode(tokens.toJson()));
|
return await sp.setString('tokens', jsonEncode(tokens.toJson()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 获取本地存储的tokens
|
/// get local stored tokens
|
||||||
///
|
///
|
||||||
/// 如果没有,则返回null
|
/// if no, returns null
|
||||||
Future<Tokens> getTokens() async {
|
Future<Tokens> getTokens() async {
|
||||||
try {
|
try {
|
||||||
var sp = await SharedPreferences.getInstance();
|
var sp = await SharedPreferences.getInstance();
|
||||||
@@ -82,7 +82,8 @@ func genVars(dir string, isLegacy bool, scheme string, hostname string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !fileExists(dir + "vars.dart") {
|
if !fileExists(dir + "vars.dart") {
|
||||||
err = os.WriteFile(dir+"vars.dart", []byte(fmt.Sprintf(`const serverHost='%s://%s';`, scheme, hostname)), 0o644)
|
err = os.WriteFile(dir+"vars.dart", []byte(fmt.Sprintf(`const serverHost='%s://%s';`,
|
||||||
|
scheme, hostname)), 0o644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,9 +42,20 @@ var (
|
|||||||
func GoFormatApi(_ *cobra.Command, _ []string) error {
|
func GoFormatApi(_ *cobra.Command, _ []string) error {
|
||||||
var be errorx.BatchError
|
var be errorx.BatchError
|
||||||
if VarBoolUseStdin {
|
if VarBoolUseStdin {
|
||||||
|
if env.UseExperimental() {
|
||||||
|
data, err := io.ReadAll(os.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
be.Add(err)
|
||||||
|
} else {
|
||||||
|
if err := apiF.Source(data, os.Stdout); err != nil {
|
||||||
|
be.Add(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if err := apiFormatReader(os.Stdin, VarStringDir, VarBoolSkipCheckDeclare); err != nil {
|
if err := apiFormatReader(os.Stdin, VarStringDir, VarBoolSkipCheckDeclare); err != nil {
|
||||||
be.Add(err)
|
be.Add(err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if len(VarStringDir) == 0 {
|
if len(VarStringDir) == 0 {
|
||||||
return errors.New("missing -dir")
|
return errors.New("missing -dir")
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ var (
|
|||||||
// VarStringStyle describes the style of output files.
|
// VarStringStyle describes the style of output files.
|
||||||
VarStringStyle string
|
VarStringStyle string
|
||||||
VarBoolWithTest bool
|
VarBoolWithTest bool
|
||||||
|
// VarBoolTypeGroup describes whether to group types.
|
||||||
|
VarBoolTypeGroup bool
|
||||||
)
|
)
|
||||||
|
|
||||||
// GoCommand gen go project files from command line
|
// GoCommand gen go project files from command line
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/collection"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
@@ -39,20 +41,152 @@ func BuildTypes(types []spec.Type) (string, error) {
|
|||||||
return builder.String(), nil
|
return builder.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
func getTypeName(tp spec.Type) string {
|
||||||
val, err := BuildTypes(api.Types)
|
if tp == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch val := tp.(type) {
|
||||||
|
case spec.DefineStruct:
|
||||||
|
typeName := util.Title(tp.Name())
|
||||||
|
return typeName
|
||||||
|
case spec.PointerType:
|
||||||
|
return getTypeName(val.Type)
|
||||||
|
case spec.ArrayType:
|
||||||
|
return getTypeName(val.Value)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func genTypesWithGroup(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
|
groupTypes := make(map[string]map[string]spec.Type)
|
||||||
|
typesBelongToFiles := make(map[string]*collection.Set)
|
||||||
|
|
||||||
|
for _, v := range api.Service.Groups {
|
||||||
|
group := v.GetAnnotation(groupProperty)
|
||||||
|
if len(group) == 0 {
|
||||||
|
group = groupTypeDefault
|
||||||
|
}
|
||||||
|
// convert filepath to Identifier name spec.
|
||||||
|
group = strings.TrimPrefix(group, "/")
|
||||||
|
group = strings.TrimSuffix(group, "/")
|
||||||
|
group = util.SafeString(group)
|
||||||
|
for _, v := range v.Routes {
|
||||||
|
requestTypeName := getTypeName(v.RequestType)
|
||||||
|
responseTypeName := getTypeName(v.ResponseType)
|
||||||
|
requestTypeFileSet, ok := typesBelongToFiles[requestTypeName]
|
||||||
|
if !ok {
|
||||||
|
requestTypeFileSet = collection.NewSet()
|
||||||
|
}
|
||||||
|
if len(requestTypeName) > 0 {
|
||||||
|
requestTypeFileSet.AddStr(group)
|
||||||
|
typesBelongToFiles[requestTypeName] = requestTypeFileSet
|
||||||
|
}
|
||||||
|
|
||||||
|
responseTypeFileSet, ok := typesBelongToFiles[responseTypeName]
|
||||||
|
if !ok {
|
||||||
|
responseTypeFileSet = collection.NewSet()
|
||||||
|
}
|
||||||
|
if len(responseTypeName) > 0 {
|
||||||
|
responseTypeFileSet.AddStr(group)
|
||||||
|
typesBelongToFiles[responseTypeName] = responseTypeFileSet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typesInOneFile := make(map[string]*collection.Set)
|
||||||
|
for typeName, fileSet := range typesBelongToFiles {
|
||||||
|
count := fileSet.Count()
|
||||||
|
switch {
|
||||||
|
case count == 0: // it means there has no structure type or no request/response body
|
||||||
|
continue
|
||||||
|
case count == 1: // it means a structure type used in only one group.
|
||||||
|
groupName := fileSet.KeysStr()[0]
|
||||||
|
typeSet, ok := typesInOneFile[groupName]
|
||||||
|
if !ok {
|
||||||
|
typeSet = collection.NewSet()
|
||||||
|
}
|
||||||
|
typeSet.AddStr(typeName)
|
||||||
|
typesInOneFile[groupName] = typeSet
|
||||||
|
default: // it means this type is used in multiple groups.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range api.Types {
|
||||||
|
typeName := util.Title(v.Name())
|
||||||
|
groupSet, ok := typesBelongToFiles[typeName]
|
||||||
|
var typeCount int
|
||||||
|
if !ok {
|
||||||
|
typeCount = 0
|
||||||
|
} else {
|
||||||
|
typeCount = groupSet.Count()
|
||||||
|
}
|
||||||
|
|
||||||
|
if typeCount == 0 { // not belong to any group
|
||||||
|
types, ok := groupTypes[groupTypeDefault]
|
||||||
|
if !ok {
|
||||||
|
types = make(map[string]spec.Type)
|
||||||
|
}
|
||||||
|
types[typeName] = v
|
||||||
|
groupTypes[groupTypeDefault] = types
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if typeCount == 1 { // belong to one group
|
||||||
|
groupName := groupSet.KeysStr()[0]
|
||||||
|
types, ok := groupTypes[groupName]
|
||||||
|
if !ok {
|
||||||
|
types = make(map[string]spec.Type)
|
||||||
|
}
|
||||||
|
types[typeName] = v
|
||||||
|
groupTypes[groupName] = types
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// belong to multiple groups
|
||||||
|
types, ok := groupTypes[groupTypeDefault]
|
||||||
|
if !ok {
|
||||||
|
types = make(map[string]spec.Type)
|
||||||
|
}
|
||||||
|
types[typeName] = v
|
||||||
|
groupTypes[groupTypeDefault] = types
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
for group, typeGroup := range groupTypes {
|
||||||
|
var types []spec.Type
|
||||||
|
for _, v := range typeGroup {
|
||||||
|
types = append(types, v)
|
||||||
|
}
|
||||||
|
sort.Slice(types, func(i, j int) bool {
|
||||||
|
return types[i].Name() < types[j].Name()
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := writeTypes(dir, group, cfg, types); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeTypes(dir, baseFilename string, cfg *config.Config, types []spec.Type) error {
|
||||||
|
if len(types) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
val, err := BuildTypes(types)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
typeFilename, err := format.FileNamingFormat(cfg.NamingFormat, typesFile)
|
typeFilename, err := format.FileNamingFormat(cfg.NamingFormat, baseFilename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
typeFilename = typeFilename + ".go"
|
typeFilename = typeFilename + ".go"
|
||||||
filename := path.Join(dir, typesDir, typeFilename)
|
filename := path.Join(dir, typesDir, typeFilename)
|
||||||
os.Remove(filename)
|
_ = os.Remove(filename)
|
||||||
|
|
||||||
return genFile(fileGenConfig{
|
return genFile(fileGenConfig{
|
||||||
dir: dir,
|
dir: dir,
|
||||||
@@ -70,6 +204,13 @@ func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
|
if VarBoolTypeGroup {
|
||||||
|
return genTypesWithGroup(dir, cfg, api)
|
||||||
|
}
|
||||||
|
return writeTypes(dir, typesFile, cfg, api.Types)
|
||||||
|
}
|
||||||
|
|
||||||
func writeType(writer io.Writer, tp spec.Type) error {
|
func writeType(writer io.Writer, tp spec.Type) error {
|
||||||
structType, ok := tp.(spec.DefineStruct)
|
structType, ok := tp.(spec.DefineStruct)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
@@ -10,4 +10,6 @@ const (
|
|||||||
middlewareDir = internal + "middleware"
|
middlewareDir = internal + "middleware"
|
||||||
typesDir = internal + typesPacket
|
typesDir = internal + typesPacket
|
||||||
groupProperty = "group"
|
groupProperty = "group"
|
||||||
|
|
||||||
|
groupTypeDefault="types"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"path"
|
"path"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
@@ -96,13 +96,13 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
|
|||||||
for _, item := range c.responseTypes {
|
for _, item := range c.responseTypes {
|
||||||
if item.Name() == defineStruct.Name() {
|
if item.Name() == defineStruct.Name() {
|
||||||
superClassName = "HttpResponseData"
|
superClassName = "HttpResponseData"
|
||||||
if !stringx.Contains(c.imports, httpResponseData) {
|
if !slices.Contains(c.imports, httpResponseData) {
|
||||||
c.imports = append(c.imports, httpResponseData)
|
c.imports = append(c.imports, httpResponseData)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if superClassName == "HttpData" && !stringx.Contains(c.imports, httpData) {
|
if superClassName == "HttpData" && !slices.Contains(c.imports, httpData) {
|
||||||
c.imports = append(c.imports, httpData)
|
c.imports = append(c.imports, httpData)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,7 +266,7 @@ func (c *componentsContext) genGetSet(writer io.Writer, indent int) error {
|
|||||||
tyString := javaType
|
tyString := javaType
|
||||||
decorator := ""
|
decorator := ""
|
||||||
javaPrimitiveType := []string{"int", "long", "boolean", "float", "double", "short"}
|
javaPrimitiveType := []string{"int", "long", "boolean", "float", "double", "short"}
|
||||||
if !stringx.Contains(javaPrimitiveType, javaType) {
|
if !slices.Contains(javaPrimitiveType, javaType) {
|
||||||
if member.IsOptional() || member.IsOmitEmpty() {
|
if member.IsOptional() || member.IsOmitEmpty() {
|
||||||
decorator = "@Nullable "
|
decorator = "@Nullable "
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ package spec
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"path"
|
"path"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -57,14 +57,14 @@ func (m Member) Tags() []*Tag {
|
|||||||
|
|
||||||
// IsOptional returns true if tag is optional
|
// IsOptional returns true if tag is optional
|
||||||
func (m Member) IsOptional() bool {
|
func (m Member) IsOptional() bool {
|
||||||
if !m.IsBodyMember() {
|
if !m.IsBodyMember() && !m.IsFormMember() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
tag := m.Tags()
|
tag := m.Tags()
|
||||||
for _, item := range tag {
|
for _, item := range tag {
|
||||||
if item.Key == bodyTagKey {
|
if item.Key == bodyTagKey || item.Key == formTagKey {
|
||||||
if stringx.Contains(item.Options, "optional") {
|
if slices.Contains(item.Options, "optional") {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -81,7 +81,7 @@ func (m Member) IsOmitEmpty() bool {
|
|||||||
tag := m.Tags()
|
tag := m.Tags()
|
||||||
for _, item := range tag {
|
for _, item := range tag {
|
||||||
if item.Key == bodyTagKey {
|
if item.Key == bodyTagKey {
|
||||||
if stringx.Contains(item.Options, "omitempty") {
|
if slices.Contains(item.Options, "omitempty") {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -93,7 +93,7 @@ func (m Member) IsOmitEmpty() bool {
|
|||||||
func (m Member) GetPropertyName() (string, error) {
|
func (m Member) GetPropertyName() (string, error) {
|
||||||
tags := m.Tags()
|
tags := m.Tags()
|
||||||
for _, tag := range tags {
|
for _, tag := range tags {
|
||||||
if stringx.Contains(definedKeys, tag.Key) {
|
if slices.Contains(definedKeys, tag.Key) {
|
||||||
if tag.Name == "-" {
|
if tag.Name == "-" {
|
||||||
return util.Untitle(m.Name), nil
|
return util.Untitle(m.Name), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ type (
|
|||||||
|
|
||||||
// ApiSpec describes an api file
|
// ApiSpec describes an api file
|
||||||
ApiSpec struct {
|
ApiSpec struct {
|
||||||
Info Info // Deprecated: useless expression
|
Info Info
|
||||||
Syntax ApiSyntax // Deprecated: useless expression
|
Syntax ApiSyntax // Deprecated: useless expression
|
||||||
Imports []Import // Deprecated: useless expression
|
Imports []Import // Deprecated: useless expression
|
||||||
Types []Type
|
Types []Type
|
||||||
@@ -59,11 +59,11 @@ type (
|
|||||||
// Member describes the field of a structure
|
// Member describes the field of a structure
|
||||||
Member struct {
|
Member struct {
|
||||||
Name string
|
Name string
|
||||||
// 数据类型字面值,如:string、map[int]string、[]int64、[]*User
|
// data type, for example, string、map[int]string、[]int64、[]*User
|
||||||
Type Type
|
Type Type
|
||||||
Tag string
|
Tag string
|
||||||
Comment string
|
Comment string
|
||||||
// 成员头顶注释说明
|
// document for the field
|
||||||
Docs Doc
|
Docs Doc
|
||||||
IsInline bool
|
IsInline bool
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,9 @@ func Parse(tag string) (*Tags, error) {
|
|||||||
|
|
||||||
// Get gets tag value by specified key
|
// Get gets tag value by specified key
|
||||||
func (t *Tags) Get(key string) (*Tag, error) {
|
func (t *Tags) Get(key string) (*Tag, error) {
|
||||||
|
if t == nil {
|
||||||
|
return nil, errTagNotExist
|
||||||
|
}
|
||||||
for _, tag := range t.tags {
|
for _, tag := range t.tags {
|
||||||
if tag.Key == key {
|
if tag.Key == key {
|
||||||
return tag, nil
|
return tag, nil
|
||||||
@@ -60,6 +63,9 @@ func (t *Tags) Get(key string) (*Tag, error) {
|
|||||||
|
|
||||||
// Keys returns all keys in Tags
|
// Keys returns all keys in Tags
|
||||||
func (t *Tags) Keys() []string {
|
func (t *Tags) Keys() []string {
|
||||||
|
if t == nil {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
var keys []string
|
var keys []string
|
||||||
for _, tag := range t.tags {
|
for _, tag := range t.tags {
|
||||||
keys = append(keys, tag.Key)
|
keys = append(keys, tag.Key)
|
||||||
@@ -69,5 +75,8 @@ func (t *Tags) Keys() []string {
|
|||||||
|
|
||||||
// Tags returns all tags in Tags
|
// Tags returns all tags in Tags
|
||||||
func (t *Tags) Tags() []*Tag {
|
func (t *Tags) Tags() []*Tag {
|
||||||
|
if t == nil {
|
||||||
|
return []*Tag{}
|
||||||
|
}
|
||||||
return t.tags
|
return t.tags
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user