mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-19 20:58:17 +08:00
Compare commits
54 Commits
tools/goct
...
v1.8.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f2b589d4d | ||
|
|
19fec36d24 | ||
|
|
f037bf344d | ||
|
|
d99cf35b07 | ||
|
|
f459f1b5ff | ||
|
|
0140fd417b | ||
|
|
7969e0ca38 | ||
|
|
91c885b5b0 | ||
|
|
d4cccca387 | ||
|
|
4b2095ed03 | ||
|
|
1229eeb2d2 | ||
|
|
9142b146c5 | ||
|
|
8a1b2d5aed | ||
|
|
da5d39e6ca | ||
|
|
68c5a17c67 | ||
|
|
b53f9f5f2d | ||
|
|
36d57626b6 | ||
|
|
4e36ba832f | ||
|
|
a44954a771 | ||
|
|
f3edd4b880 | ||
|
|
2de3e397ff | ||
|
|
a435eb56f2 | ||
|
|
d80761c147 | ||
|
|
e7bd0d8b60 | ||
|
|
b109b3ef4c | ||
|
|
e3c371ac89 | ||
|
|
15eb6f4f6d | ||
|
|
4d3681b71c | ||
|
|
a682bda0bb | ||
|
|
45b27ad93a | ||
|
|
292a8302a1 | ||
|
|
91ab1f6d2b | ||
|
|
5048c350ae | ||
|
|
94edc32f3e | ||
|
|
ec989b2e2a | ||
|
|
82fe802e81 | ||
|
|
072d68f897 | ||
|
|
2e91ba5811 | ||
|
|
5564c43197 | ||
|
|
e55158b0f7 | ||
|
|
69aa7fe346 | ||
|
|
c3820a95c1 | ||
|
|
493f3bad0f | ||
|
|
eb0d5ad3a4 | ||
|
|
14192050ae | ||
|
|
9193e771e3 | ||
|
|
808b4e496a | ||
|
|
e416d01f8d | ||
|
|
789c5de873 | ||
|
|
52078a0c14 | ||
|
|
7ef13116a0 | ||
|
|
6b8053410a | ||
|
|
81c6928445 | ||
|
|
761c2dd716 |
@@ -8,16 +8,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
numHistoryReasons = 5
|
||||
timeFormat = "15:04:05"
|
||||
)
|
||||
const numHistoryReasons = 5
|
||||
|
||||
// ErrServiceUnavailable is returned when the Breaker state is open.
|
||||
var ErrServiceUnavailable = errors.New("circuit breaker is open")
|
||||
@@ -262,9 +258,9 @@ type errorWindow struct {
|
||||
|
||||
func (ew *errorWindow) add(reason string) {
|
||||
ew.lock.Lock()
|
||||
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason)
|
||||
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(time.TimeOnly), reason)
|
||||
ew.index = (ew.index + 1) % numHistoryReasons
|
||||
ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
|
||||
ew.count = min(ew.count+1, numHistoryReasons)
|
||||
ew.lock.Unlock()
|
||||
}
|
||||
|
||||
|
||||
@@ -86,21 +86,16 @@ func TestConsistentHashIncrementalTransfer(t *testing.T) {
|
||||
|
||||
func TestConsistentHashTransferOnFailure(t *testing.T) {
|
||||
index := 41
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
||||
var transferred int
|
||||
for k, v := range newKeys {
|
||||
if v != keys[k] {
|
||||
transferred++
|
||||
}
|
||||
}
|
||||
|
||||
ratio := float32(transferred) / float32(requestSize)
|
||||
assert.True(t, ratio < 2.5/float32(keySize), fmt.Sprintf("%d: %f", index, ratio))
|
||||
ratioNotExists := getTransferRatioOnFailure(t, index)
|
||||
assert.True(t, ratioNotExists == 0, fmt.Sprintf("%d: %f", index, ratioNotExists))
|
||||
index = 13
|
||||
ratio := getTransferRatioOnFailure(t, index)
|
||||
assert.True(t, ratio < 2.5/keySize, fmt.Sprintf("%d: %f", index, ratio))
|
||||
}
|
||||
|
||||
func TestConsistentHashLeastTransferOnFailure(t *testing.T) {
|
||||
prefix := "localhost:"
|
||||
index := 41
|
||||
index := 13
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index)
|
||||
for k, v := range keys {
|
||||
newV := newKeys[k]
|
||||
@@ -164,6 +159,17 @@ func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[i
|
||||
return keys, newKeys
|
||||
}
|
||||
|
||||
func getTransferRatioOnFailure(t *testing.T, index int) float32 {
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
||||
var transferred int
|
||||
for k, v := range newKeys {
|
||||
if v != keys[k] {
|
||||
transferred++
|
||||
}
|
||||
}
|
||||
return float32(transferred) / float32(requestSize)
|
||||
}
|
||||
|
||||
type mockNode struct {
|
||||
addr string
|
||||
id int
|
||||
|
||||
@@ -2,7 +2,7 @@ package hash
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/spaolacci/murmur3"
|
||||
)
|
||||
@@ -20,6 +20,7 @@ func Md5(data []byte) []byte {
|
||||
}
|
||||
|
||||
// Md5Hex returns the md5 hex string of data.
|
||||
// This function is optimized for better performance than fmt.Sprintf.
|
||||
func Md5Hex(data []byte) string {
|
||||
return fmt.Sprintf("%x", Md5(data))
|
||||
return hex.EncodeToString(Md5(data))
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
dateFormat = "2006-01-02"
|
||||
hoursPerDay = 24
|
||||
bufferSize = 100
|
||||
defaultDirMode = 0o755
|
||||
@@ -116,7 +115,7 @@ func (r *DailyRotateRule) OutdatedFiles() []string {
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(time.DateOnly)
|
||||
buf.WriteString(r.filename)
|
||||
buf.WriteString(r.delimiter)
|
||||
buf.WriteString(boundary)
|
||||
@@ -425,7 +424,7 @@ func compressLogFile(file string) {
|
||||
}
|
||||
|
||||
func getNowDate() string {
|
||||
return time.Now().Format(dateFormat)
|
||||
return time.Now().Format(time.DateOnly)
|
||||
}
|
||||
|
||||
func getNowDateInRFC3339Format() string {
|
||||
|
||||
@@ -52,7 +52,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("temp files", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
_ = f1.Close()
|
||||
@@ -73,7 +73,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
||||
|
||||
func TestDailyRotateRuleShallRotate(t *testing.T) {
|
||||
var rule DailyRotateRule
|
||||
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(time.DateOnly)
|
||||
assert.True(t, rule.ShallRotate(0))
|
||||
}
|
||||
|
||||
@@ -117,12 +117,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("temp files", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
@@ -144,12 +144,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("no backups", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
@@ -319,7 +319,7 @@ func TestRotateLoggerWrite(t *testing.T) {
|
||||
}
|
||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||
logger.write([]byte(`foo`))
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||
logger.write([]byte(`bar`))
|
||||
logger.Close()
|
||||
logger.write([]byte(`baz`))
|
||||
@@ -447,7 +447,7 @@ func TestRotateLoggerWithSizeLimitRotateRuleWrite(t *testing.T) {
|
||||
}
|
||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||
logger.write([]byte(`foo`))
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||
logger.write([]byte(`bar`))
|
||||
logger.Close()
|
||||
logger.write([]byte(`baz`))
|
||||
|
||||
@@ -142,7 +142,7 @@ func validateOptional(field reflect.StructField, value reflect.Value) error {
|
||||
if value.IsNil() {
|
||||
return fmt.Errorf("field %q is nil", field.Name)
|
||||
}
|
||||
case reflect.Array, reflect.Slice, reflect.Map:
|
||||
case reflect.Slice, reflect.Map:
|
||||
if value.IsNil() || value.Len() == 0 {
|
||||
return fmt.Errorf("field %q is empty", field.Name)
|
||||
}
|
||||
|
||||
@@ -462,3 +462,15 @@ func TestMarshal_FromString(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10", m["json"]["age"].(string))
|
||||
}
|
||||
|
||||
func TestMarshal_Array(t *testing.T) {
|
||||
v := struct {
|
||||
H [1]int `json:"h,string"`
|
||||
}{
|
||||
H: [1]int{1},
|
||||
}
|
||||
|
||||
m, err := Marshal(v)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "[1]", m["json"]["h"].(string))
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/jsonx"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -894,7 +894,7 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ
|
||||
valueKind.String())
|
||||
}
|
||||
|
||||
if !stringx.Contains(options, checkValue) {
|
||||
if !slices.Contains(options, checkValue) {
|
||||
return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`,
|
||||
mapValue, key, options)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -634,11 +635,11 @@ func validateValueInOptions(val any, options []string) error {
|
||||
if len(options) > 0 {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
if !stringx.Contains(options, v) {
|
||||
if !slices.Contains(options, v) {
|
||||
return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options)
|
||||
}
|
||||
default:
|
||||
if !stringx.Contains(options, Repr(v)) {
|
||||
if !slices.Contains(options, Repr(v)) {
|
||||
return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,19 +1,13 @@
|
||||
package mathx
|
||||
|
||||
// MaxInt returns the larger one of a and b.
|
||||
// Deprecated: use builtin max instead.
|
||||
func MaxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
return max(a, b)
|
||||
}
|
||||
|
||||
// MinInt returns the smaller one of a and b.
|
||||
// Deprecated: use builtin min instead.
|
||||
func MinInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
return min(a, b)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package prof
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strconv"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
@@ -28,46 +27,15 @@ type (
|
||||
|
||||
const flushInterval = 5 * time.Minute
|
||||
|
||||
var (
|
||||
pc = &profileCenter{
|
||||
slots: make(map[string]*profileSlot),
|
||||
}
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func report(name string, duration time.Duration) {
|
||||
updated := func() bool {
|
||||
pc.lock.RLock()
|
||||
defer pc.lock.RUnlock()
|
||||
|
||||
slot, ok := pc.slots[name]
|
||||
if ok {
|
||||
atomic.AddInt64(&slot.lifecount, 1)
|
||||
atomic.AddInt64(&slot.lastcount, 1)
|
||||
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
||||
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
||||
}
|
||||
return ok
|
||||
}()
|
||||
|
||||
if !updated {
|
||||
func() {
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
|
||||
pc.slots[name] = &profileSlot{
|
||||
lifecount: 1,
|
||||
lastcount: 1,
|
||||
lifecycle: int64(duration),
|
||||
lastcycle: int64(duration),
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
once.Do(flushRepeatly)
|
||||
var pc = &profileCenter{
|
||||
slots: make(map[string]*profileSlot),
|
||||
}
|
||||
|
||||
func flushRepeatly() {
|
||||
func init() {
|
||||
flushRepeatedly()
|
||||
}
|
||||
|
||||
func flushRepeatedly() {
|
||||
threading.GoSafe(func() {
|
||||
for {
|
||||
time.Sleep(flushInterval)
|
||||
@@ -76,42 +44,64 @@ func flushRepeatly() {
|
||||
})
|
||||
}
|
||||
|
||||
func report(name string, duration time.Duration) {
|
||||
slot := loadOrStoreSlot(name, duration)
|
||||
|
||||
atomic.AddInt64(&slot.lifecount, 1)
|
||||
atomic.AddInt64(&slot.lastcount, 1)
|
||||
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
||||
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
||||
}
|
||||
|
||||
func loadOrStoreSlot(name string, duration time.Duration) *profileSlot {
|
||||
pc.lock.RLock()
|
||||
slot, ok := pc.slots[name]
|
||||
pc.lock.RUnlock()
|
||||
|
||||
if ok {
|
||||
return slot
|
||||
}
|
||||
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
|
||||
// double-check
|
||||
if slot, ok = pc.slots[name]; ok {
|
||||
return slot
|
||||
}
|
||||
|
||||
slot = &profileSlot{}
|
||||
pc.slots[name] = slot
|
||||
return slot
|
||||
}
|
||||
|
||||
func generateReport() string {
|
||||
var buffer bytes.Buffer
|
||||
buffer.WriteString("Profiling report\n")
|
||||
var data [][]string
|
||||
var builder strings.Builder
|
||||
builder.WriteString("Profiling report\n")
|
||||
builder.WriteString("QUEUE,LIFECOUNT,LIFECYCLE,LASTCOUNT,LASTCYCLE\n")
|
||||
|
||||
calcFn := func(total, count int64) string {
|
||||
if count == 0 {
|
||||
return "-"
|
||||
}
|
||||
|
||||
return (time.Duration(total) / time.Duration(count)).String()
|
||||
}
|
||||
|
||||
func() {
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
pc.lock.Lock()
|
||||
for key, slot := range pc.slots {
|
||||
builder.WriteString(fmt.Sprintf("%s,%d,%s,%d,%s\n",
|
||||
key,
|
||||
slot.lifecount,
|
||||
calcFn(slot.lifecycle, slot.lifecount),
|
||||
slot.lastcount,
|
||||
calcFn(slot.lastcycle, slot.lastcount),
|
||||
))
|
||||
|
||||
for key, slot := range pc.slots {
|
||||
data = append(data, []string{
|
||||
key,
|
||||
strconv.FormatInt(slot.lifecount, 10),
|
||||
calcFn(slot.lifecycle, slot.lifecount),
|
||||
strconv.FormatInt(slot.lastcount, 10),
|
||||
calcFn(slot.lastcycle, slot.lastcount),
|
||||
})
|
||||
// reset last cycle stats
|
||||
atomic.StoreInt64(&slot.lastcount, 0)
|
||||
atomic.StoreInt64(&slot.lastcycle, 0)
|
||||
}
|
||||
pc.lock.Unlock()
|
||||
|
||||
// reset the data for last cycle
|
||||
slot.lastcount = 0
|
||||
slot.lastcycle = 0
|
||||
}
|
||||
}()
|
||||
|
||||
table := tablewriter.NewWriter(&buffer)
|
||||
table.SetHeader([]string{"QUEUE", "LIFECOUNT", "LIFECYCLE", "LASTCOUNT", "LASTCYCLE"})
|
||||
table.SetBorder(false)
|
||||
table.AppendBulk(data)
|
||||
table.Render()
|
||||
|
||||
return buffer.String()
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
)
|
||||
|
||||
func TestReport(t *testing.T) {
|
||||
once.Do(func() {})
|
||||
assert.NotContains(t, generateReport(), "foo")
|
||||
report("foo", time.Second)
|
||||
assert.Contains(t, generateReport(), "foo")
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"github.com/zeromicro/go-zero/internal/devserver"
|
||||
"github.com/zeromicro/go-zero/internal/profiling"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -38,6 +39,8 @@ type (
|
||||
Telemetry trace.Config `json:",optional"`
|
||||
DevServer DevServerConfig `json:",optional"`
|
||||
Shutdown proc.ShutdownConf `json:",optional"`
|
||||
// Profiling is the configuration for continuous profiling.
|
||||
Profiling profiling.Config `json:",optional"`
|
||||
}
|
||||
)
|
||||
|
||||
@@ -70,7 +73,9 @@ func (sc ServiceConf) SetUp() error {
|
||||
if len(sc.MetricsUrl) > 0 {
|
||||
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
|
||||
}
|
||||
|
||||
devserver.StartAgent(sc.DevServer)
|
||||
profiling.Start(sc.Profiling)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
|
||||
@@ -35,7 +36,7 @@ type (
|
||||
// NewServiceGroup returns a ServiceGroup.
|
||||
func NewServiceGroup() *ServiceGroup {
|
||||
sg := new(ServiceGroup)
|
||||
sg.stopOnce = syncx.Once(sg.doStop)
|
||||
sg.stopOnce = sync.OnceFunc(sg.doStop)
|
||||
return sg
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
const (
|
||||
clusterNameKey = "CLUSTER_NAME"
|
||||
testEnv = "test.v"
|
||||
timeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -45,7 +44,7 @@ func Report(msg string) {
|
||||
if fn != nil {
|
||||
reported := lessExecutor.DoOrDiscard(func() {
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintln(time.Now().Format(timeFormat)))
|
||||
builder.WriteString(fmt.Sprintln(time.Now().Format(time.DateTime)))
|
||||
if len(clusterName) > 0 {
|
||||
builder.WriteString(fmt.Sprintf("cluster: %s\n", clusterName))
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package stringx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
"unicode"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
@@ -15,14 +16,9 @@ var (
|
||||
)
|
||||
|
||||
// Contains checks if str is in list.
|
||||
// Deprecated: use slices.Contains instead.
|
||||
func Contains(list []string, str string) bool {
|
||||
for _, each := range list {
|
||||
if each == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return slices.Contains(list, str)
|
||||
}
|
||||
|
||||
// Filter filters chars from s with given filter function.
|
||||
@@ -123,11 +119,7 @@ func Remove(strings []string, strs ...string) []string {
|
||||
// Reverse reverses s.
|
||||
func Reverse(s string) string {
|
||||
runes := []rune(s)
|
||||
|
||||
for from, to := 0, len(runes)-1; from < to; from, to = from+1, to-1 {
|
||||
runes[from], runes[to] = runes[to], runes[from]
|
||||
}
|
||||
|
||||
slices.Reverse(runes)
|
||||
return string(runes)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,28 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestContainsString(t *testing.T) {
|
||||
cases := []struct {
|
||||
slice []string
|
||||
value string
|
||||
expect bool
|
||||
}{
|
||||
{[]string{"1"}, "1", true},
|
||||
{[]string{"1"}, "2", false},
|
||||
{[]string{"1", "2"}, "1", true},
|
||||
{[]string{"1", "2"}, "3", false},
|
||||
{nil, "3", false},
|
||||
{nil, "", false},
|
||||
}
|
||||
|
||||
for _, each := range cases {
|
||||
t.Run(path.Join(each.slice...), func(t *testing.T) {
|
||||
actual := Contains(each.slice, each.value)
|
||||
assert.Equal(t, each.expect, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotEmpty(t *testing.T) {
|
||||
cases := []struct {
|
||||
args []string
|
||||
@@ -41,28 +63,6 @@ func TestNotEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsString(t *testing.T) {
|
||||
cases := []struct {
|
||||
slice []string
|
||||
value string
|
||||
expect bool
|
||||
}{
|
||||
{[]string{"1"}, "1", true},
|
||||
{[]string{"1"}, "2", false},
|
||||
{[]string{"1", "2"}, "1", true},
|
||||
{[]string{"1", "2"}, "3", false},
|
||||
{nil, "3", false},
|
||||
{nil, "", false},
|
||||
}
|
||||
|
||||
for _, each := range cases {
|
||||
t.Run(path.Join(each.slice...), func(t *testing.T) {
|
||||
actual := Contains(each.slice, each.value)
|
||||
assert.Equal(t, each.expect, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
|
||||
@@ -3,9 +3,7 @@ package syncx
|
||||
import "sync"
|
||||
|
||||
// Once returns a func that guarantees fn can only called once.
|
||||
// Deprecated: use sync.OnceFunc instead.
|
||||
func Once(fn func()) func() {
|
||||
once := new(sync.Once)
|
||||
return func() {
|
||||
once.Do(fn)
|
||||
}
|
||||
return sync.OnceFunc(fn)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ func compare(v1, v2 string) int {
|
||||
fields1, fields2 := strings.Split(v1, "."), strings.Split(v2, ".")
|
||||
ver1, ver2 := strsToInts(fields1), strsToInts(fields2)
|
||||
ver1len, ver2len := len(ver1), len(ver2)
|
||||
shorter := mathx.MinInt(ver1len, ver2len)
|
||||
shorter := min(ver1len, ver2len)
|
||||
|
||||
for i := 0; i < shorter; i++ {
|
||||
if ver1[i] == ver2[i] {
|
||||
@@ -50,14 +50,7 @@ func compare(v1, v2 string) int {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
if ver1len < ver2len {
|
||||
return -1
|
||||
} else if ver1len == ver2len {
|
||||
return 0
|
||||
} else {
|
||||
return 1
|
||||
}
|
||||
return cmp.Compare(ver1len, ver2len)
|
||||
}
|
||||
|
||||
func strsToInts(strs []string) []int64 {
|
||||
|
||||
@@ -201,6 +201,13 @@ func TestHttpToHttp(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("method not allowed", func(t *testing.T) {
|
||||
resp, err := httpc.Do(context.Background(), http.MethodPost,
|
||||
"http://localhost:18882/api/ping", nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHttpToHttpBadUpstream(t *testing.T) {
|
||||
|
||||
10
go.mod
10
go.mod
@@ -4,7 +4,7 @@ go 1.21
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/alicebob/miniredis/v2 v2.34.0
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/fatih/color v1.18.0
|
||||
github.com/fullstorydev/grpcurl v1.9.3
|
||||
github.com/go-sql-driver/mysql v1.9.0
|
||||
@@ -12,17 +12,17 @@ require (
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/golang/protobuf v1.5.4
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/grafana/pyroscope-go v1.2.2
|
||||
github.com/jackc/pgx/v5 v5.7.4
|
||||
github.com/jhump/protoreflect v1.17.0
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/pelletier/go-toml/v2 v2.2.2
|
||||
github.com/prometheus/client_golang v1.21.1
|
||||
github.com/redis/go-redis/v9 v9.7.3
|
||||
github.com/redis/go-redis/v9 v9.10.0
|
||||
github.com/spaolacci/murmur3 v1.1.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
go.etcd.io/etcd/api/v3 v3.5.15
|
||||
go.etcd.io/etcd/client/v3 v3.5.15
|
||||
go.mongodb.org/mongo-driver v1.17.3
|
||||
go.mongodb.org/mongo-driver v1.17.4
|
||||
go.opentelemetry.io/otel v1.24.0
|
||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
|
||||
@@ -50,7 +50,6 @@ require (
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bufbuild/protocompile v0.14.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
@@ -73,6 +72,7 @@ require (
|
||||
github.com/google/gnostic-models v0.6.8 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/gofuzz v1.2.0 // indirect
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
|
||||
21
go.sum
21
go.sum
@@ -2,10 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
|
||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
|
||||
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
|
||||
github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
|
||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
@@ -82,6 +80,10 @@ github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJY
|
||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
|
||||
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
||||
@@ -121,7 +123,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -135,8 +136,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4=
|
||||
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
|
||||
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
|
||||
github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg=
|
||||
@@ -159,8 +158,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||
github.com/redis/go-redis/v9 v9.10.0 h1:FxwK3eV8p/CQa0Ch276C7u2d0eNC9kCmAYQ7mCXCzVs=
|
||||
github.com/redis/go-redis/v9 v9.10.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
@@ -203,8 +202,8 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5
|
||||
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
|
||||
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
|
||||
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
|
||||
go.mongodb.org/mongo-driver v1.17.3 h1:TQyXhnsWfWtgAhMtOgtYHMTkZIfBTpMTsMnd9ZBeHxQ=
|
||||
go.mongodb.org/mongo-driver v1.17.3/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||
go.mongodb.org/mongo-driver v1.17.4 h1:jUorfmVzljjr0FLzYQsGP8cgN/qzzxlY9Vh0C9KFXVw=
|
||||
go.mongodb.org/mongo-driver v1.17.4/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
||||
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=
|
||||
|
||||
263
internal/profiling/profiling.go
Normal file
263
internal/profiling/profiling.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package profiling
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/pyroscope-go"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCheckInterval = time.Second * 10
|
||||
defaultProfilingDuration = time.Minute * 2
|
||||
defaultUploadRate = time.Second * 15
|
||||
)
|
||||
|
||||
type (
|
||||
Config struct {
|
||||
// Name is the name of the application.
|
||||
Name string `json:",optional,inherit"`
|
||||
// ServerAddr is the address of the profiling server.
|
||||
ServerAddr string
|
||||
// AuthUser is the username for basic authentication.
|
||||
AuthUser string `json:",optional"`
|
||||
// AuthPassword is the password for basic authentication.
|
||||
AuthPassword string `json:",optional"`
|
||||
// UploadRate is the duration for which profiling data is uploaded.
|
||||
UploadRate time.Duration `json:",default=15s"`
|
||||
// CheckInterval is the interval to check if profiling should start.
|
||||
CheckInterval time.Duration `json:",default=10s"`
|
||||
// ProfilingDuration is the duration for which profiling data is collected.
|
||||
ProfilingDuration time.Duration `json:",default=2m"`
|
||||
// CpuThreshold the collection is allowed only when the current service cpu < CpuThreshold
|
||||
CpuThreshold int64 `json:",default=700,range=[0:1000)"`
|
||||
|
||||
// ProfileType is the type of profiling to be performed.
|
||||
ProfileType ProfileType
|
||||
}
|
||||
|
||||
ProfileType struct {
|
||||
// Logger is a flag to enable or disable logging.
|
||||
Logger bool `json:",default=false"`
|
||||
// CPU is a flag to disable CPU profiling.
|
||||
CPU bool `json:",default=true"`
|
||||
// Goroutines is a flag to disable goroutine profiling.
|
||||
Goroutines bool `json:",default=true"`
|
||||
// Memory is a flag to disable memory profiling.
|
||||
Memory bool `json:",default=true"`
|
||||
// Mutex is a flag to disable mutex profiling.
|
||||
Mutex bool `json:",default=false"`
|
||||
// Block is a flag to disable block profiling.
|
||||
Block bool `json:",default=false"`
|
||||
}
|
||||
|
||||
profiler interface {
|
||||
Start() error
|
||||
Stop() error
|
||||
}
|
||||
|
||||
pyroscopeProfiler struct {
|
||||
c Config
|
||||
profiler *pyroscope.Profiler
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
|
||||
newProfiler = func(c Config) profiler {
|
||||
return newPyroscopeProfiler(c)
|
||||
}
|
||||
)
|
||||
|
||||
// Start initializes the pyroscope profiler with the given configuration.
|
||||
func Start(c Config) {
|
||||
// check if the profiling is enabled
|
||||
if len(c.ServerAddr) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// set default values for the configuration
|
||||
if c.ProfilingDuration <= 0 {
|
||||
c.ProfilingDuration = defaultProfilingDuration
|
||||
}
|
||||
|
||||
// set default values for the configuration
|
||||
if c.CheckInterval <= 0 {
|
||||
c.CheckInterval = defaultCheckInterval
|
||||
}
|
||||
|
||||
if c.UploadRate <= 0 {
|
||||
c.UploadRate = defaultUploadRate
|
||||
}
|
||||
|
||||
once.Do(func() {
|
||||
logx.Info("continuous profiling started")
|
||||
|
||||
threading.GoSafe(func() {
|
||||
startPyroscope(c, proc.Done())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// startPyroscope starts the pyroscope profiler with the given configuration.
|
||||
func startPyroscope(c Config, done <-chan struct{}) {
|
||||
var (
|
||||
pr profiler
|
||||
err error
|
||||
latestProfilingTime time.Time
|
||||
intervalTicker = time.NewTicker(c.CheckInterval)
|
||||
profilingTicker = time.NewTicker(c.ProfilingDuration)
|
||||
)
|
||||
|
||||
defer profilingTicker.Stop()
|
||||
defer intervalTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-intervalTicker.C:
|
||||
// Check if the machine is overloaded and if the profiler is not running
|
||||
if pr == nil && isCpuOverloaded(c) {
|
||||
pr = newProfiler(c)
|
||||
if err := pr.Start(); err != nil {
|
||||
logx.Errorf("failed to start profiler: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// record the latest profiling time
|
||||
latestProfilingTime = time.Now()
|
||||
logx.Infof("pyroscope profiler started.")
|
||||
}
|
||||
case <-profilingTicker.C:
|
||||
// check if the profiling duration has passed
|
||||
if !time.Now().After(latestProfilingTime.Add(c.ProfilingDuration)) {
|
||||
continue
|
||||
}
|
||||
|
||||
// check if the profiler is already running, if so, skip
|
||||
if pr != nil {
|
||||
if err = pr.Stop(); err != nil {
|
||||
logx.Errorf("failed to stop profiler: %v", err)
|
||||
}
|
||||
logx.Infof("pyroscope profiler stopped.")
|
||||
pr = nil
|
||||
}
|
||||
case <-done:
|
||||
logx.Infof("continuous profiling stopped.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// genPyroscopeConf generates the pyroscope configuration based on the given config.
|
||||
func genPyroscopeConf(c Config) pyroscope.Config {
|
||||
pConf := pyroscope.Config{
|
||||
UploadRate: c.UploadRate,
|
||||
ApplicationName: c.Name,
|
||||
BasicAuthUser: c.AuthUser, // http basic auth user
|
||||
BasicAuthPassword: c.AuthPassword, // http basic auth password
|
||||
ServerAddress: c.ServerAddr,
|
||||
Logger: nil,
|
||||
HTTPHeaders: map[string]string{},
|
||||
// you can provide static tags via a map:
|
||||
Tags: map[string]string{
|
||||
"name": c.Name,
|
||||
},
|
||||
}
|
||||
|
||||
if c.ProfileType.Logger {
|
||||
pConf.Logger = logx.WithCallerSkip(0)
|
||||
}
|
||||
|
||||
if c.ProfileType.CPU {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileCPU)
|
||||
}
|
||||
if c.ProfileType.Goroutines {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileGoroutines)
|
||||
}
|
||||
if c.ProfileType.Memory {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileAllocObjects, pyroscope.ProfileAllocSpace,
|
||||
pyroscope.ProfileInuseObjects, pyroscope.ProfileInuseSpace)
|
||||
}
|
||||
if c.ProfileType.Mutex {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileMutexCount, pyroscope.ProfileMutexDuration)
|
||||
}
|
||||
if c.ProfileType.Block {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileBlockCount, pyroscope.ProfileBlockDuration)
|
||||
}
|
||||
|
||||
logx.Infof("applicationName: %s", pConf.ApplicationName)
|
||||
|
||||
return pConf
|
||||
}
|
||||
|
||||
// isCpuOverloaded checks the machine performance based on the given configuration.
|
||||
func isCpuOverloaded(c Config) bool {
|
||||
currentValue := stat.CpuUsage()
|
||||
if currentValue >= c.CpuThreshold {
|
||||
logx.Infof("continuous profiling cpu overload, cpu: %d", currentValue)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func newPyroscopeProfiler(c Config) profiler {
|
||||
return &pyroscopeProfiler{
|
||||
c: c,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pyroscopeProfiler) Start() error {
|
||||
pConf := genPyroscopeConf(p.c)
|
||||
// set mutex and block profile rate
|
||||
setFraction(p.c)
|
||||
prof, err := pyroscope.Start(pConf)
|
||||
if err != nil {
|
||||
resetFraction(p.c)
|
||||
return err
|
||||
}
|
||||
|
||||
p.profiler = prof
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pyroscopeProfiler) Stop() error {
|
||||
if p.profiler == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := p.profiler.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resetFraction(p.c)
|
||||
p.profiler = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setFraction(c Config) {
|
||||
// These 2 lines are only required if you're using mutex or block profiling
|
||||
if c.ProfileType.Mutex {
|
||||
runtime.SetMutexProfileFraction(10) // 10/seconds
|
||||
}
|
||||
if c.ProfileType.Block {
|
||||
runtime.SetBlockProfileRate(1000 * 1000) // 1/millisecond
|
||||
}
|
||||
}
|
||||
|
||||
func resetFraction(c Config) {
|
||||
// These 2 lines are only required if you're using mutex or block profiling
|
||||
if c.ProfileType.Mutex {
|
||||
runtime.SetMutexProfileFraction(0)
|
||||
}
|
||||
if c.ProfileType.Block {
|
||||
runtime.SetBlockProfileRate(0)
|
||||
}
|
||||
}
|
||||
177
internal/profiling/profiling_test.go
Normal file
177
internal/profiling/profiling_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package profiling
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/pyroscope-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
func TestStart(t *testing.T) {
|
||||
t.Run("profiling", func(t *testing.T) {
|
||||
var c Config
|
||||
assert.NoError(t, conf.FillDefault(&c))
|
||||
c.Name = "test"
|
||||
p := newProfiler(c)
|
||||
assert.NotNil(t, p)
|
||||
assert.NoError(t, p.Start())
|
||||
assert.NoError(t, p.Stop())
|
||||
})
|
||||
|
||||
t.Run("invalid config", func(t *testing.T) {
|
||||
mp := &mockProfiler{}
|
||||
newProfiler = func(c Config) profiler {
|
||||
return mp
|
||||
}
|
||||
|
||||
Start(Config{})
|
||||
|
||||
Start(Config{
|
||||
ServerAddr: "localhost:4040",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("test start profiler", func(t *testing.T) {
|
||||
mp := &mockProfiler{}
|
||||
newProfiler = func(c Config) profiler {
|
||||
return mp
|
||||
}
|
||||
|
||||
c := Config{
|
||||
Name: "test",
|
||||
ServerAddr: "localhost:4040",
|
||||
CheckInterval: time.Millisecond,
|
||||
ProfilingDuration: time.Millisecond * 10,
|
||||
CpuThreshold: 0,
|
||||
}
|
||||
var done = make(chan struct{})
|
||||
go startPyroscope(c, done)
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
close(done)
|
||||
|
||||
assert.True(t, mp.started.True())
|
||||
assert.True(t, mp.stopped.True())
|
||||
})
|
||||
|
||||
t.Run("test start profiler with cpu overloaded", func(t *testing.T) {
|
||||
mp := &mockProfiler{}
|
||||
newProfiler = func(c Config) profiler {
|
||||
return mp
|
||||
}
|
||||
|
||||
c := Config{
|
||||
Name: "test",
|
||||
ServerAddr: "localhost:4040",
|
||||
CheckInterval: time.Millisecond,
|
||||
ProfilingDuration: time.Millisecond * 10,
|
||||
CpuThreshold: 900,
|
||||
}
|
||||
var done = make(chan struct{})
|
||||
go startPyroscope(c, done)
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
close(done)
|
||||
|
||||
assert.False(t, mp.started.True())
|
||||
})
|
||||
|
||||
t.Run("start/stop err", func(t *testing.T) {
|
||||
mp := &mockProfiler{
|
||||
err: assert.AnError,
|
||||
}
|
||||
newProfiler = func(c Config) profiler {
|
||||
return mp
|
||||
}
|
||||
|
||||
c := Config{
|
||||
Name: "test",
|
||||
ServerAddr: "localhost:4040",
|
||||
CheckInterval: time.Millisecond,
|
||||
ProfilingDuration: time.Millisecond * 10,
|
||||
CpuThreshold: 0,
|
||||
}
|
||||
var done = make(chan struct{})
|
||||
go startPyroscope(c, done)
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
close(done)
|
||||
|
||||
assert.False(t, mp.started.True())
|
||||
assert.False(t, mp.stopped.True())
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenPyroscopeConf(t *testing.T) {
|
||||
c := Config{
|
||||
Name: "",
|
||||
ServerAddr: "localhost:4040",
|
||||
AuthUser: "user",
|
||||
AuthPassword: "password",
|
||||
ProfileType: ProfileType{
|
||||
Logger: true,
|
||||
CPU: true,
|
||||
Goroutines: true,
|
||||
Memory: true,
|
||||
Mutex: true,
|
||||
Block: true,
|
||||
},
|
||||
}
|
||||
|
||||
pyroscopeConf := genPyroscopeConf(c)
|
||||
assert.Equal(t, c.ServerAddr, pyroscopeConf.ServerAddress)
|
||||
assert.Equal(t, c.AuthUser, pyroscopeConf.BasicAuthUser)
|
||||
assert.Equal(t, c.AuthPassword, pyroscopeConf.BasicAuthPassword)
|
||||
assert.Equal(t, c.Name, pyroscopeConf.ApplicationName)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileCPU)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileGoroutines)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocObjects)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocSpace)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseObjects)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseSpace)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexCount)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexDuration)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockCount)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockDuration)
|
||||
|
||||
setFraction(c)
|
||||
resetFraction(c)
|
||||
|
||||
newPyroscopeProfiler(c)
|
||||
}
|
||||
|
||||
func TestNewPyroscopeProfiler(t *testing.T) {
|
||||
p := newPyroscopeProfiler(Config{})
|
||||
|
||||
assert.Error(t, p.Start())
|
||||
assert.NoError(t, p.Stop())
|
||||
}
|
||||
|
||||
type mockProfiler struct {
|
||||
mutex sync.Mutex
|
||||
started syncx.AtomicBool
|
||||
stopped syncx.AtomicBool
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProfiler) Start() error {
|
||||
m.mutex.Lock()
|
||||
if m.err == nil {
|
||||
m.started.Set(true)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockProfiler) Stop() error {
|
||||
m.mutex.Lock()
|
||||
if m.err == nil {
|
||||
m.stopped.Set(true)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return m.err
|
||||
}
|
||||
@@ -34,7 +34,10 @@ type McpConf struct {
|
||||
// Cors contains allowed CORS origins
|
||||
Cors []string `json:",optional"`
|
||||
|
||||
// ToolTimeout is the maximum time allowed for tool execution
|
||||
ToolTimeout time.Duration `json:",default=30s"`
|
||||
// SseTimeout is the maximum time allowed for SSE connections
|
||||
SseTimeout time.Duration `json:",default=24h"`
|
||||
|
||||
// MessageTimeout is the maximum time allowed for request execution
|
||||
MessageTimeout time.Duration `json:",default=30s"`
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ mcp:
|
||||
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
||||
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
||||
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
||||
assert.Equal(t, 30*time.Second, c.Mcp.ToolTimeout, "Default tool timeout should be 30s")
|
||||
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s")
|
||||
}
|
||||
|
||||
func TestMcpConfCustomValues(t *testing.T) {
|
||||
@@ -43,7 +43,7 @@ func TestMcpConfCustomValues(t *testing.T) {
|
||||
"SseEndpoint": "/custom-sse",
|
||||
"MessageEndpoint": "/custom-message",
|
||||
"Cors": ["http://localhost:3000", "http://example.com"],
|
||||
"ToolTimeout": "60s"
|
||||
"MessageTimeout": "60s"
|
||||
}
|
||||
}`
|
||||
|
||||
@@ -59,5 +59,5 @@ func TestMcpConfCustomValues(t *testing.T) {
|
||||
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
||||
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
||||
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
||||
assert.Equal(t, 60*time.Second, c.Mcp.ToolTimeout, "Tool timeout should be customizable")
|
||||
assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable")
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -59,7 +60,7 @@ func TestHTTPHandlerIntegration(t *testing.T) {
|
||||
conf := McpConf{}
|
||||
conf.Mcp.Name = "test-integration"
|
||||
conf.Mcp.Version = "1.0.0-test"
|
||||
conf.Mcp.ToolTimeout = 1 * time.Second
|
||||
conf.Mcp.MessageTimeout = 1 * time.Second
|
||||
|
||||
// Create a mock server directly
|
||||
server := &sseMcpServer{
|
||||
@@ -75,7 +76,6 @@ func TestHTTPHandlerIntegration(t *testing.T) {
|
||||
Name: "echo",
|
||||
Description: "Echo tool for testing",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
@@ -83,7 +83,7 @@ func TestHTTPHandlerIntegration(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
if msg, ok := params["message"].(string); ok {
|
||||
return fmt.Sprintf("Echo: %s", msg), nil
|
||||
}
|
||||
@@ -181,7 +181,7 @@ func TestHandlerResponseFlow(t *testing.T) {
|
||||
Name: "test.tool",
|
||||
Description: "Test tool",
|
||||
InputSchema: InputSchema{Type: "object"},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return "tool result", nil
|
||||
},
|
||||
})
|
||||
@@ -329,7 +329,7 @@ func TestProcessListMethods(t *testing.T) {
|
||||
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
||||
}
|
||||
|
||||
server.processListTools(client, req)
|
||||
server.processListTools(context.Background(), client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
@@ -344,7 +344,7 @@ func TestProcessListMethods(t *testing.T) {
|
||||
req.ID = 2
|
||||
req.Method = methodPromptsList
|
||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||
server.processListPrompts(client, req)
|
||||
server.processListPrompts(context.Background(), client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
@@ -358,7 +358,7 @@ func TestProcessListMethods(t *testing.T) {
|
||||
req.ID = 3
|
||||
req.Method = methodResourcesList
|
||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||
server.processListResources(client, req)
|
||||
server.processListResources(context.Background(), client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
@@ -393,7 +393,7 @@ func TestErrorResponseHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
// Mock handleRequest by directly calling error handler
|
||||
server.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
@@ -412,7 +412,7 @@ func TestErrorResponseHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
// Call process method directly
|
||||
server.processToolCall(client, toolReq)
|
||||
server.processToolCall(context.Background(), client, toolReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
@@ -431,7 +431,7 @@ func TestErrorResponseHandling(t *testing.T) {
|
||||
}
|
||||
|
||||
// Call process method directly
|
||||
server.processGetPrompt(client, promptReq)
|
||||
server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
|
||||
23
mcp/parser.go
Normal file
23
mcp/parser.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
)
|
||||
|
||||
// ParseArguments parses the arguments and populates the request object
|
||||
func ParseArguments(args any, req any) error {
|
||||
switch arguments := args.(type) {
|
||||
case map[string]string:
|
||||
m := make(map[string]any, len(arguments))
|
||||
for k, v := range arguments {
|
||||
m[k] = v
|
||||
}
|
||||
return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues())
|
||||
case map[string]any:
|
||||
return mapping.UnmarshalJsonMap(arguments, req)
|
||||
default:
|
||||
return fmt.Errorf("unsupported argument type: %T", arguments)
|
||||
}
|
||||
}
|
||||
139
mcp/parser_test.go
Normal file
139
mcp/parser_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestParseArguments_MapStringString tests parsing map[string]string arguments
|
||||
func TestParseArguments_MapStringString(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
Count int `json:"count"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// Create test arguments
|
||||
args := map[string]string{
|
||||
"name": "test-name",
|
||||
"message": "hello world",
|
||||
"count": "42",
|
||||
"enabled": "true",
|
||||
}
|
||||
|
||||
// Create a target object to populate
|
||||
var req requestStruct
|
||||
|
||||
// Parse the arguments
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err, "Should parse map[string]string without error")
|
||||
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||
assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int")
|
||||
assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool")
|
||||
}
|
||||
|
||||
// TestParseArguments_MapStringAny tests parsing map[string]any arguments
|
||||
func TestParseArguments_MapStringAny(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
Count int `json:"count"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Tags []string `json:"tags"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
}
|
||||
|
||||
// Create test arguments with mixed types
|
||||
args := map[string]any{
|
||||
"name": "test-name",
|
||||
"message": "hello world",
|
||||
"count": 42, // note: this is already an int
|
||||
"enabled": true, // note: this is already a bool
|
||||
"tags": []string{"tag1", "tag2"},
|
||||
"metadata": map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
},
|
||||
}
|
||||
|
||||
// Create a target object to populate
|
||||
var req requestStruct
|
||||
|
||||
// Parse the arguments
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err, "Should parse map[string]any without error")
|
||||
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||
assert.Equal(t, 42, req.Count, "Count should be correctly parsed")
|
||||
assert.True(t, req.Enabled, "Enabled should be correctly parsed")
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed")
|
||||
assert.Equal(t, map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
}, req.Metadata, "Metadata should be correctly parsed")
|
||||
}
|
||||
|
||||
// TestParseArguments_UnsupportedType tests parsing with an unsupported type
|
||||
func TestParseArguments_UnsupportedType(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Use an unsupported argument type (slice)
|
||||
args := []string{"not", "a", "map"}
|
||||
|
||||
// Create a target object to populate
|
||||
var req requestStruct
|
||||
|
||||
// Parse the arguments
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
// Verify error is returned with correct message
|
||||
assert.Error(t, err, "Should return error for unsupported type")
|
||||
assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type")
|
||||
assert.Contains(t, err.Error(), "[]string", "Error should include the actual type")
|
||||
}
|
||||
|
||||
// TestParseArguments_EmptyMap tests parsing with empty maps
|
||||
func TestParseArguments_EmptyMap(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name,optional"`
|
||||
Message string `json:"message,optional"`
|
||||
}
|
||||
|
||||
// Test empty map[string]string
|
||||
t.Run("EmptyMapStringString", func(t *testing.T) {
|
||||
args := map[string]string{}
|
||||
var req requestStruct
|
||||
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
assert.NoError(t, err, "Should parse empty map[string]string without error")
|
||||
assert.Empty(t, req.Name, "Name should be empty string")
|
||||
assert.Empty(t, req.Message, "Message should be empty string")
|
||||
})
|
||||
|
||||
// Test empty map[string]any
|
||||
t.Run("EmptyMapStringAny", func(t *testing.T) {
|
||||
args := map[string]any{}
|
||||
var req requestStruct
|
||||
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
assert.NoError(t, err, "Should parse empty map[string]any without error")
|
||||
assert.Empty(t, req.Name, "Name should be empty string")
|
||||
assert.Empty(t, req.Message, "Message should be empty string")
|
||||
})
|
||||
}
|
||||
824
mcp/readme.md
824
mcp/readme.md
@@ -1,7 +1,7 @@
|
||||
# Model Context Protocol (MCP) SDK Implementation
|
||||
# Model Context Protocol (MCP) Implementation
|
||||
|
||||
## Overview
|
||||
This package implements a Model Context Protocol (MCP) server in Go that facilitates real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation provides a framework for building AI-assisted applications with bidirectional communication capabilities.
|
||||
This package implements the Model Context Protocol (MCP) server specification in Go, providing a framework for real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation follows the standardized protocol for building AI-assisted applications with bidirectional communication capabilities.
|
||||
|
||||
## Core Components
|
||||
|
||||
@@ -54,9 +54,817 @@ This package implements a Model Context Protocol (MCP) server in Go that facilit
|
||||
|
||||
## Usage
|
||||
|
||||
To create and use an MCP server, see the examples directory for practical implementation examples including:
|
||||
- Tool registration and execution
|
||||
- Static and dynamic prompt creation
|
||||
- Resource handling with proper URI identification
|
||||
- Embedded resources in prompt responses
|
||||
- Client connection management
|
||||
### Setting Up an MCP Server
|
||||
|
||||
To create and start an MCP server:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/mcp"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration from YAML file
|
||||
var c mcp.McpConf
|
||||
conf.MustLoad("config.yaml", &c)
|
||||
|
||||
// Optional: Disable stats logging
|
||||
logx.DisableStat()
|
||||
|
||||
// Create MCP server
|
||||
server := mcp.NewMcpServer(c)
|
||||
|
||||
// Register tools, prompts, and resources (examples below)
|
||||
|
||||
// Start the server and ensure it's stopped on exit
|
||||
defer server.Stop()
|
||||
server.Start()
|
||||
}
|
||||
```
|
||||
|
||||
Sample configuration file (config.yaml):
|
||||
|
||||
```yaml
|
||||
name: mcp-server
|
||||
host: localhost
|
||||
port: 8080
|
||||
mcp:
|
||||
name: my-mcp-server
|
||||
messageTimeout: 30s # Timeout for tool calls
|
||||
cors:
|
||||
- http://localhost:3000 # Optional CORS configuration
|
||||
```
|
||||
|
||||
### Registering Tools
|
||||
|
||||
Tools allow AI models to execute custom code through the MCP protocol.
|
||||
|
||||
#### Basic Tool Example:
|
||||
|
||||
```go
|
||||
// Register a simple echo tool
|
||||
echoTool := mcp.Tool{
|
||||
Name: "echo",
|
||||
Description: "Echoes back the message provided by the user",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The message to echo back",
|
||||
},
|
||||
"prefix": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional prefix to add to the echoed message",
|
||||
"default": "Echo: ",
|
||||
},
|
||||
},
|
||||
Required: []string{"message"},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
var req struct {
|
||||
Message string `json:"message"`
|
||||
Prefix string `json:"prefix,optional"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||
}
|
||||
|
||||
prefix := "Echo: "
|
||||
if len(req.Prefix) > 0 {
|
||||
prefix = req.Prefix
|
||||
}
|
||||
|
||||
return prefix + req.Message, nil
|
||||
},
|
||||
}
|
||||
|
||||
server.RegisterTool(echoTool)
|
||||
```
|
||||
|
||||
#### Tool with Different Response Types:
|
||||
|
||||
```go
|
||||
// Tool returning JSON data
|
||||
dataTool := mcp.Tool{
|
||||
Name: "data.generate",
|
||||
Description: "Generates sample data in various formats",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"format": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Format of data (json, text)",
|
||||
"enum": []string{"json", "text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
var req struct {
|
||||
Format string `json:"format"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||
}
|
||||
|
||||
if req.Format == "json" {
|
||||
// Return structured data
|
||||
return map[string]any{
|
||||
"items": []map[string]any{
|
||||
{"id": 1, "name": "Item 1"},
|
||||
{"id": 2, "name": "Item 2"},
|
||||
},
|
||||
"count": 2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Default to text
|
||||
return "Sample text data", nil
|
||||
},
|
||||
}
|
||||
|
||||
server.RegisterTool(dataTool)
|
||||
```
|
||||
|
||||
#### Image Generation Tool Example:
|
||||
|
||||
```go
|
||||
// Tool returning image content
|
||||
imageTool := mcp.Tool{
|
||||
Name: "image.generate",
|
||||
Description: "Generates a simple image",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"type": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Type of image to generate",
|
||||
"default": "placeholder",
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return image content directly
|
||||
return mcp.ImageContent{
|
||||
Data: "base64EncodedImageData...", // Base64 encoded image data
|
||||
MimeType: "image/png",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
server.RegisterTool(imageTool)
|
||||
```
|
||||
|
||||
#### Using ToolResult for Custom Outputs:
|
||||
|
||||
```go
|
||||
// Tool that returns a custom ToolResult type
|
||||
customResultTool := mcp.Tool{
|
||||
Name: "custom.result",
|
||||
Description: "Returns a custom formatted result",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"resultType": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"text", "image"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
var req struct {
|
||||
ResultType string `json:"resultType"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||
}
|
||||
|
||||
if req.ResultType == "image" {
|
||||
return mcp.ToolResult{
|
||||
Type: mcp.ContentTypeImage,
|
||||
Content: map[string]any{
|
||||
"data": "base64EncodedImageData...",
|
||||
"mimeType": "image/jpeg",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Default to text
|
||||
return mcp.ToolResult{
|
||||
Type: mcp.ContentTypeText,
|
||||
Content: "This is a text result from ToolResult",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
server.RegisterTool(customResultTool)
|
||||
```
|
||||
|
||||
### Registering Prompts
|
||||
|
||||
Prompts are reusable conversation templates for AI models.
|
||||
|
||||
#### Static Prompt Example:
|
||||
|
||||
```go
|
||||
// Register a simple static prompt with placeholders
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "hello",
|
||||
Description: "A simple hello prompt",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "name",
|
||||
Description: "The name to greet",
|
||||
Required: false,
|
||||
},
|
||||
},
|
||||
Content: "Say hello to {{name}} and introduce yourself as an AI assistant.",
|
||||
})
|
||||
```
|
||||
|
||||
#### Dynamic Prompt with Handler Function:
|
||||
|
||||
```go
|
||||
// Register a prompt with a dynamic handler function
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "dynamic-prompt",
|
||||
Description: "A prompt that uses a handler to generate dynamic content",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "username",
|
||||
Description: "User's name for personalized greeting",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "topic",
|
||||
Description: "Topic of expertise",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Topic string `json:"topic"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
// Create a user message
|
||||
userMessage := mcp.PromptMessage{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||
},
|
||||
}
|
||||
|
||||
// Create an assistant response with current time
|
||||
currentTime := time.Now().Format(time.RFC1123)
|
||||
assistantMessage := mcp.PromptMessage{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||
req.Username, req.Topic, currentTime),
|
||||
},
|
||||
}
|
||||
|
||||
// Return both messages as a conversation
|
||||
return []mcp.PromptMessage{userMessage, assistantMessage}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### Multi-Message Prompt with Code Examples:
|
||||
|
||||
```go
|
||||
// Register a prompt that provides code examples in different programming languages
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "code-example",
|
||||
Description: "Provides code examples in different programming languages",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "language",
|
||||
Description: "Programming language for the example",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "complexity",
|
||||
Description: "Complexity level (simple, medium, advanced)",
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
Language string `json:"language"`
|
||||
Complexity string `json:"complexity,optional"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
// Validate language
|
||||
supportedLanguages := map[string]bool{"go": true, "python": true, "javascript": true, "rust": true}
|
||||
if !supportedLanguages[req.Language] {
|
||||
return nil, fmt.Errorf("unsupported language: %s", req.Language)
|
||||
}
|
||||
|
||||
// Generate code example based on language and complexity
|
||||
var codeExample string
|
||||
|
||||
switch req.Language {
|
||||
case "go":
|
||||
if req.Complexity == "simple" {
|
||||
codeExample = `
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
func main() {
|
||||
fmt.Println("Hello, World!")
|
||||
}`
|
||||
} else {
|
||||
codeExample = `
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
now := time.Now()
|
||||
fmt.Printf("Hello, World! Current time is %s\n", now.Format(time.RFC3339))
|
||||
}`
|
||||
}
|
||||
case "python":
|
||||
// Python example code
|
||||
if req.Complexity == "simple" {
|
||||
codeExample = `
|
||||
def greet(name):
|
||||
return f"Hello, {name}!"
|
||||
|
||||
print(greet("World"))`
|
||||
} else {
|
||||
codeExample = `
|
||||
import datetime
|
||||
|
||||
def greet(name, include_time=False):
|
||||
message = f"Hello, {name}!"
|
||||
if include_time:
|
||||
message += f" Current time is {datetime.datetime.now().isoformat()}"
|
||||
return message
|
||||
|
||||
print(greet("World", include_time=True))`
|
||||
}
|
||||
}
|
||||
|
||||
// Create messages array according to MCP spec
|
||||
messages := []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("You are a helpful coding assistant specialized in %s programming.", req.Language),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Show me a %s example of a Hello World program in %s.", req.Complexity, req.Language),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Here's a %s example in %s:\n\n```%s%s\n```\n\nHow can I help you implement this?",
|
||||
req.Complexity, req.Language, req.Language, codeExample),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Registering Resources
|
||||
|
||||
Resources provide access to external content such as files or generated data.
|
||||
|
||||
#### Basic Resource Example:
|
||||
|
||||
```go
|
||||
// Register a static resource
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "example-document",
|
||||
URI: "file:///example/document.txt",
|
||||
Description: "An example document",
|
||||
MimeType: "text/plain",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///example/document.txt",
|
||||
MimeType: "text/plain",
|
||||
Text: "This is an example document content.",
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### Dynamic Resource with Code Example:
|
||||
|
||||
```go
|
||||
// Register a Go code resource with dynamic handler
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "go-example",
|
||||
URI: "file:///project/src/main.go",
|
||||
Description: "A simple Go example with multiple files",
|
||||
MimeType: "text/x-go",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
// Return ResourceContent with all required fields
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///project/src/main.go",
|
||||
MimeType: "text/x-go",
|
||||
Text: "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}",
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register a companion file for the above example
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "go-greeting",
|
||||
URI: "file:///project/src/greeting/greeting.go",
|
||||
Description: "A greeting package for the Go example",
|
||||
MimeType: "text/x-go",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///project/src/greeting/greeting.go",
|
||||
MimeType: "text/x-go",
|
||||
Text: "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}",
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### Binary Resource Example:
|
||||
|
||||
```go
|
||||
// Register a binary resource (like an image)
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "example-image",
|
||||
URI: "file:///example/image.png",
|
||||
Description: "An example image",
|
||||
MimeType: "image/png",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
// Read image from file or generate it
|
||||
imageData := "base64EncodedImageData..." // Base64 encoded image data
|
||||
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///example/image.png",
|
||||
MimeType: "image/png",
|
||||
Blob: imageData, // For binary data
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Using Resources in Prompts
|
||||
|
||||
You can embed resources in prompt responses to create rich interactions with proper MCP-compliant structure:
|
||||
|
||||
```go
|
||||
// Register a prompt that embeds a resource
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "resource-example",
|
||||
Description: "A prompt that embeds a resource",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "file_type",
|
||||
Description: "Type of file to show (rust or go)",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
FileType string `json:"file_type"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
var resourceURI, mimeType, fileContent string
|
||||
if req.FileType == "rust" {
|
||||
resourceURI = "file:///project/src/main.rs"
|
||||
mimeType = "text/x-rust"
|
||||
fileContent = "fn main() {\n println!(\"Hello world!\");\n}"
|
||||
} else {
|
||||
resourceURI = "file:///project/src/main.go"
|
||||
mimeType = "text/x-go"
|
||||
fileContent = "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello, world!\")\n}"
|
||||
}
|
||||
|
||||
// Create message with embedded resource using proper MCP format
|
||||
return []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Can you explain this %s code?", req.FileType),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: mcp.ContentTypeResource,
|
||||
Resource: struct {
|
||||
URI string `json:"uri"`
|
||||
MimeType string `json:"mimeType"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Blob string `json:"blob,omitempty"`
|
||||
}{
|
||||
URI: resourceURI,
|
||||
MimeType: mimeType,
|
||||
Text: fileContent,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Above is a simple Hello World example in %s. Let me explain how it works.", req.FileType),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Multiple File Resources Example
|
||||
|
||||
```go
|
||||
// Register a prompt that demonstrates embedding multiple resource files
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "go-code-example",
|
||||
Description: "A prompt that correctly embeds multiple resource files",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "format",
|
||||
Description: "How to format the code display",
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
Format string `json:"format,optional"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
// Get the Go code for multiple files
|
||||
var mainGoText string = "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}"
|
||||
var greetingGoText string = "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}"
|
||||
|
||||
// Create message with properly formatted embedded resource per MCP spec
|
||||
messages := []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: "Show me a simple Go example with proper imports.",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: "Here's a simple Go example project:",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: mcp.ContentTypeResource,
|
||||
Resource: struct {
|
||||
URI string `json:"uri"`
|
||||
MimeType string `json:"mimeType"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Blob string `json:"blob,omitempty"`
|
||||
}{
|
||||
URI: "file:///project/src/main.go",
|
||||
MimeType: "text/x-go",
|
||||
Text: mainGoText,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add explanation and additional file if requested
|
||||
if req.Format == "with_explanation" {
|
||||
messages = append(messages, mcp.PromptMessage{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: "This example demonstrates a simple Go application with modular structure. The main.go file imports from a local 'greeting' package that provides the Hello function.",
|
||||
},
|
||||
})
|
||||
|
||||
// Also show the greeting.go file with correct resource format
|
||||
messages = append(messages, mcp.PromptMessage{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: mcp.ContentTypeResource,
|
||||
Resource: struct {
|
||||
URI string `json:"uri"`
|
||||
MimeType string `json:"mimeType"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Blob string `json:"blob,omitempty"`
|
||||
}{
|
||||
URI: "file:///project/src/greeting/greeting.go",
|
||||
MimeType: "text/x-go",
|
||||
Text: greetingGoText,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Complete Application Example
|
||||
|
||||
Here's a complete example demonstrating all the components:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/mcp"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration
|
||||
var c mcp.McpConf
|
||||
if err := conf.Load("config.yaml", &c); err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Set up logging
|
||||
logx.DisableStat()
|
||||
|
||||
// Create MCP server
|
||||
server := mcp.NewMcpServer(c)
|
||||
defer server.Stop()
|
||||
|
||||
// Register a simple echo tool
|
||||
echoTool := mcp.Tool{
|
||||
Name: "echo",
|
||||
Description: "Echoes back the message provided by the user",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The message to echo back",
|
||||
},
|
||||
"prefix": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional prefix to add to the echoed message",
|
||||
"default": "Echo: ",
|
||||
},
|
||||
},
|
||||
Required: []string{"message"},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
var req struct {
|
||||
Message string `json:"message"`
|
||||
Prefix string `json:"prefix,optional"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
prefix := "Echo: "
|
||||
if len(req.Prefix) > 0 {
|
||||
prefix = req.Prefix
|
||||
}
|
||||
|
||||
return prefix + req.Message, nil
|
||||
},
|
||||
}
|
||||
server.RegisterTool(echoTool)
|
||||
|
||||
// Register a static prompt
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "greeting",
|
||||
Description: "A simple greeting prompt",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "name",
|
||||
Description: "The name to greet",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Content: "Hello {{name}}! How can I assist you today?",
|
||||
})
|
||||
|
||||
// Register a dynamic prompt
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "dynamic-prompt",
|
||||
Description: "A prompt that uses a handler to generate dynamic content",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "username",
|
||||
Description: "User's name for personalized greeting",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "topic",
|
||||
Description: "Topic of expertise",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Topic string `json:"topic"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
// Create messages with current time
|
||||
currentTime := time.Now().Format(time.RFC1123)
|
||||
return []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||
req.Username, req.Topic, currentTime),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register a resource
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "example-doc",
|
||||
URI: "file:///example/doc.txt",
|
||||
Description: "An example document",
|
||||
MimeType: "text/plain",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///example/doc.txt",
|
||||
MimeType: "text/plain",
|
||||
Text: "This is the content of the example document.",
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
// Start the server
|
||||
fmt.Printf("Starting MCP server on %s:%d\n", c.Host, c.Port)
|
||||
server.Start()
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The MCP implementation provides comprehensive error handling:
|
||||
|
||||
- Tool execution errors are properly reported back to clients
|
||||
- Missing or invalid parameters are detected and reported with appropriate error codes
|
||||
- Resource and prompt lookup failures are handled gracefully
|
||||
- Timeout handling for long-running tool executions using context
|
||||
- Panic recovery to prevent server crashes
|
||||
|
||||
## Advanced Features
|
||||
|
||||
- **Annotations**: Add audience and priority metadata to content
|
||||
- **Content Types**: Support for text, images, audio, and other content formats
|
||||
- **Embedded Resources**: Include file resources directly in prompt responses
|
||||
- **Context Awareness**: All handlers receive context.Context for timeout and cancellation support
|
||||
- **Progress Tokens**: Support for tracking progress of long-running operations
|
||||
- **Customizable Timeouts**: Configure execution timeouts for tools and operations
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- Tool execution runs with configurable timeouts to prevent blocking
|
||||
- Efficient client tracking and cleanup to prevent resource leaks
|
||||
- Proper concurrency handling with mutex protection for shared resources
|
||||
- Buffered message channels to prevent blocking on client message delivery
|
||||
|
||||
290
mcp/server.go
290
mcp/server.go
@@ -42,14 +42,14 @@ func NewMcpServer(c McpConf) McpServer {
|
||||
Method: http.MethodGet,
|
||||
Path: s.conf.Mcp.SseEndpoint,
|
||||
Handler: s.handleSSE,
|
||||
}, rest.WithSSE())
|
||||
}, rest.WithSSE(), rest.WithTimeout(c.Mcp.SseTimeout))
|
||||
|
||||
// JSON-RPC message endpoint for regular requests
|
||||
s.server.AddRoute(rest.Route{
|
||||
Method: http.MethodPost,
|
||||
Path: s.conf.Mcp.MessageEndpoint,
|
||||
Handler: s.handleRequest,
|
||||
})
|
||||
}, rest.WithTimeout(c.Mcp.MessageTimeout))
|
||||
|
||||
return s
|
||||
}
|
||||
@@ -173,30 +173,35 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
|
||||
// For notification methods (no ID), we don't send a response
|
||||
isNotification := req.ID == 0
|
||||
isNotification, err := req.isNotification()
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid request.ID", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
|
||||
// Special handling for initialization sequence
|
||||
// Always allow initialize and notifications/initialized regardless of client state
|
||||
if req.Method == methodInitialize {
|
||||
logx.Infof("Processing initialize request with ID: %d", req.ID)
|
||||
s.processInitialize(client, req)
|
||||
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
|
||||
logx.Infof("Processing initialize request with ID: %v", req.ID)
|
||||
s.processInitialize(r.Context(), client, req)
|
||||
logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID)
|
||||
return
|
||||
} else if req.Method == methodNotificationsInitialized {
|
||||
// Handle initialized notification
|
||||
logx.Info("Received notifications/initialized notification")
|
||||
if !isNotification {
|
||||
s.sendErrorResponse(client, req.ID, "Method should be used as a notification", errCodeInvalidRequest)
|
||||
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||
"Method should be used as a notification", errCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
s.processNotificationInitialized(client)
|
||||
return
|
||||
} else if !client.initialized && req.Method != methodNotificationsCancelled {
|
||||
// Block most requests until client is initialized (except for cancellations)
|
||||
s.sendErrorResponse(client, req.ID, "Client not fully initialized, waiting for notifications/initialized",
|
||||
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||
"Client not fully initialized, waiting for notifications/initialized",
|
||||
errCodeClientNotInitialized)
|
||||
return
|
||||
}
|
||||
@@ -204,42 +209,42 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
// Process normal requests only after initialization
|
||||
switch req.Method {
|
||||
case methodToolsCall:
|
||||
logx.Infof("Received tools call request with ID: %d", req.ID)
|
||||
s.processToolCall(client, req)
|
||||
logx.Infof("Sent tools call response for ID: %d", req.ID)
|
||||
logx.Infof("Received tools call request with ID: %v", req.ID)
|
||||
s.processToolCall(r.Context(), client, req)
|
||||
logx.Infof("Sent tools call response for ID: %v", req.ID)
|
||||
case methodToolsList:
|
||||
logx.Infof("Processing tools/list request with ID: %d", req.ID)
|
||||
s.processListTools(client, req)
|
||||
logx.Infof("Sent tools/list response for ID: %d", req.ID)
|
||||
logx.Infof("Processing tools/list request with ID: %v", req.ID)
|
||||
s.processListTools(r.Context(), client, req)
|
||||
logx.Infof("Sent tools/list response for ID: %v", req.ID)
|
||||
case methodPromptsList:
|
||||
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
|
||||
s.processListPrompts(client, req)
|
||||
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
|
||||
logx.Infof("Processing prompts/list request with ID: %v", req.ID)
|
||||
s.processListPrompts(r.Context(), client, req)
|
||||
logx.Infof("Sent prompts/list response for ID: %v", req.ID)
|
||||
case methodPromptsGet:
|
||||
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
|
||||
s.processGetPrompt(client, req)
|
||||
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
|
||||
logx.Infof("Processing prompts/get request with ID: %v", req.ID)
|
||||
s.processGetPrompt(r.Context(), client, req)
|
||||
logx.Infof("Sent prompts/get response for ID: %v", req.ID)
|
||||
case methodResourcesList:
|
||||
logx.Infof("Processing resources/list request with ID: %d", req.ID)
|
||||
s.processListResources(client, req)
|
||||
logx.Infof("Sent resources/list response for ID: %d", req.ID)
|
||||
logx.Infof("Processing resources/list request with ID: %v", req.ID)
|
||||
s.processListResources(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/list response for ID: %v", req.ID)
|
||||
case methodResourcesRead:
|
||||
logx.Infof("Processing resources/read request with ID: %d", req.ID)
|
||||
s.processResourcesRead(client, req)
|
||||
logx.Infof("Sent resources/read response for ID: %d", req.ID)
|
||||
logx.Infof("Processing resources/read request with ID: %v", req.ID)
|
||||
s.processResourcesRead(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/read response for ID: %v", req.ID)
|
||||
case methodResourcesSubscribe:
|
||||
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
|
||||
s.processResourceSubscribe(client, req)
|
||||
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
|
||||
logx.Infof("Processing resources/subscribe request with ID: %v", req.ID)
|
||||
s.processResourceSubscribe(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/subscribe response for ID: %v", req.ID)
|
||||
case methodPing:
|
||||
logx.Infof("Processing ping request with ID: %d", req.ID)
|
||||
s.processPing(client, req)
|
||||
logx.Infof("Processing ping request with ID: %v", req.ID)
|
||||
s.processPing(r.Context(), client, req)
|
||||
case methodNotificationsCancelled:
|
||||
logx.Infof("Received notifications/cancelled notification: %v", req.Params)
|
||||
s.processNotificationCancelled(client, req)
|
||||
logx.Infof("Received notifications/cancelled notification: %v", req.ID)
|
||||
s.processNotificationCancelled(r.Context(), client, req)
|
||||
default:
|
||||
logx.Infof("Unknown method: %s", req.Method)
|
||||
s.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID)
|
||||
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -321,7 +326,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
flusher.Flush()
|
||||
case <-r.Context().Done():
|
||||
// Client disconnected or request was canceled
|
||||
// Client disconnected or request was canceled or timed out
|
||||
logx.Infof("Client %s disconnected: context done", sessionId)
|
||||
return
|
||||
}
|
||||
@@ -329,7 +334,7 @@ func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// processInitialize processes the initialize request
|
||||
func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processInitialize(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Create a proper JSON-RPC response that preserves the client's request ID
|
||||
result := initializationResponse{
|
||||
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
|
||||
@@ -362,11 +367,11 @@ func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
|
||||
client.initialized = true
|
||||
|
||||
// Send response with client's original request ID
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListTools processes the tools/list request
|
||||
func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processListTools(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
var progressToken any
|
||||
@@ -390,6 +395,9 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
||||
var toolsList []Tool
|
||||
s.toolsLock.Lock()
|
||||
for _, tool := range s.tools {
|
||||
if len(tool.InputSchema.Type) == 0 {
|
||||
tool.InputSchema.Type = ContentTypeObject
|
||||
}
|
||||
toolsList = append(toolsList, tool)
|
||||
}
|
||||
s.toolsLock.Unlock()
|
||||
@@ -405,15 +413,15 @@ func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
result.Result.Meta = map[string]any{
|
||||
"progressToken": progressToken,
|
||||
progressTokenKey: progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListPrompts processes the prompts/list request
|
||||
func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processListPrompts(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
if req.Params != nil {
|
||||
@@ -447,11 +455,11 @@ func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
|
||||
NextCursor: nextCursor,
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListResources processes the resources/list request
|
||||
func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processListResources(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
var progressToken any
|
||||
@@ -493,15 +501,15 @@ func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
result.Result.Meta = map[string]any{
|
||||
"progressToken": progressToken,
|
||||
progressTokenKey: progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processGetPrompt processes the prompts/get request
|
||||
func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processGetPrompt(ctx context.Context, client *mcpClient, req Request) {
|
||||
type GetPromptParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]string `json:"arguments,omitempty"`
|
||||
@@ -509,7 +517,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
|
||||
var params GetPromptParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -519,7 +527,7 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
s.promptsLock.Unlock()
|
||||
if !exists {
|
||||
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
|
||||
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -529,12 +537,15 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
missingArgs := validatePromptArguments(prompt, params.Arguments)
|
||||
if len(missingArgs) > 0 {
|
||||
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
|
||||
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Apply default values for missing optional arguments
|
||||
args := applyDefaultArguments(prompt, params.Arguments)
|
||||
// Ensure arguments are initialized to an empty map if nil
|
||||
if params.Arguments == nil {
|
||||
params.Arguments = make(map[string]string)
|
||||
}
|
||||
args := params.Arguments
|
||||
|
||||
// Generate messages using handler or static content
|
||||
var messages []PromptMessage
|
||||
@@ -542,17 +553,17 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
|
||||
if prompt.Handler != nil {
|
||||
// Use dynamic handler to generate messages
|
||||
logx.Info("Using prompt handler to generate content")
|
||||
messages, err = prompt.Handler(args)
|
||||
messages, err = prompt.Handler(ctx, args)
|
||||
if err != nil {
|
||||
logx.Errorf("Error from prompt handler: %v", err)
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||
s.sendErrorResponse(ctx, client, req.ID,
|
||||
fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// No handler, generate messages from static content
|
||||
var messageText string
|
||||
if prompt.Content != "" {
|
||||
if len(prompt.Content) > 0 {
|
||||
messageText = prompt.Content
|
||||
|
||||
// Apply argument substitutions to static content
|
||||
@@ -560,21 +571,13 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||
messageText = strings.Replace(messageText, placeholder, value, -1)
|
||||
}
|
||||
} else {
|
||||
// No content, use a default fallback
|
||||
topic := "this topic"
|
||||
if t, ok := args["topic"]; ok && t != "" {
|
||||
topic = t
|
||||
}
|
||||
messageText = fmt.Sprintf("Tell me about %s", topic)
|
||||
}
|
||||
|
||||
// Create a single user message with the content
|
||||
messages = []PromptMessage{
|
||||
{
|
||||
Role: roleUser,
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: messageText,
|
||||
},
|
||||
},
|
||||
@@ -587,49 +590,14 @@ func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
|
||||
Messages []PromptMessage `json:"messages"`
|
||||
}{
|
||||
Description: prompt.Description,
|
||||
Messages: messages,
|
||||
Messages: toTypedPromptMessages(messages),
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
}
|
||||
|
||||
// validatePromptArguments checks if all required arguments are provided
|
||||
// Returns a list of missing required arguments
|
||||
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||
var missingArgs []string
|
||||
|
||||
for _, arg := range prompt.Arguments {
|
||||
if arg.Required {
|
||||
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||
missingArgs = append(missingArgs, arg.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return missingArgs
|
||||
}
|
||||
|
||||
// applyDefaultArguments adds default values for missing optional arguments
|
||||
func applyDefaultArguments(prompt Prompt, providedArgs map[string]string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
// Copy all provided arguments
|
||||
for k, v := range providedArgs {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
// Add defaults for missing arguments
|
||||
for _, arg := range prompt.Arguments {
|
||||
if _, exists := result[arg.Name]; !exists && arg.Default != "" {
|
||||
result[arg.Name] = arg.Default
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processToolCall processes the tools/call request
|
||||
func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processToolCall(ctx context.Context, client *mcpClient, req Request) {
|
||||
var toolCallParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
@@ -642,7 +610,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
// If it's a RawMessage (JSON), unmarshal it
|
||||
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
|
||||
logx.Errorf("Failed to unmarshal tool call params: %v", err)
|
||||
s.sendErrorResponse(client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -654,15 +622,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
tool, exists := s.tools[toolCallParams.Name]
|
||||
s.toolsLock.Unlock()
|
||||
if !exists {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||
toolCallParams.Name), errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a context with the configured timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.conf.Mcp.ToolTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Log parameters before execution
|
||||
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
|
||||
|
||||
@@ -671,6 +635,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
var err error
|
||||
|
||||
// Create a channel to receive the result
|
||||
// make sure to have 1 size buffer to avoid channel leak if timeout
|
||||
resultCh := make(chan struct {
|
||||
result any
|
||||
err error
|
||||
@@ -678,7 +643,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
|
||||
// Execute the tool handler in a goroutine
|
||||
go func() {
|
||||
toolResult, toolErr := tool.Handler(toolCallParams.Arguments)
|
||||
toolResult, toolErr := tool.Handler(ctx, toolCallParams.Arguments)
|
||||
resultCh <- struct {
|
||||
result any
|
||||
err error
|
||||
@@ -694,9 +659,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
result = res.result
|
||||
err = res.err
|
||||
case <-ctx.Done():
|
||||
// Handle timeout
|
||||
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.ToolTimeout, toolCallParams.Name)
|
||||
s.sendErrorResponse(client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||
// Handle request timeout
|
||||
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.MessageTimeout, toolCallParams.Name)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -710,7 +675,7 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
callToolResult.Result.Meta = map[string]any{
|
||||
"progressToken": progressToken,
|
||||
progressTokenKey: progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -722,12 +687,11 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
|
||||
callToolResult.Content = []any{
|
||||
TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("Error: %v", err),
|
||||
},
|
||||
}
|
||||
callToolResult.IsError = true
|
||||
s.sendResponse(client, req.ID, callToolResult)
|
||||
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -736,10 +700,9 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
case string:
|
||||
// Simple string becomes text content
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: v,
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
case map[string]any:
|
||||
@@ -749,69 +712,63 @@ func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
|
||||
jsonStr = []byte(err.Error())
|
||||
}
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: string(jsonStr),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
case TextContent:
|
||||
// Direct TextContent object
|
||||
callToolResult.Content = append(callToolResult.Content, v)
|
||||
case ImageContent:
|
||||
// Direct ImageContent object
|
||||
callToolResult.Content = append(callToolResult.Content, v)
|
||||
case []any:
|
||||
// Array of content items
|
||||
callToolResult.Content = v
|
||||
case ToolResult:
|
||||
// Handle legacy ToolResult type
|
||||
switch v.Type {
|
||||
case contentTypeText:
|
||||
case ContentTypeText:
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("%v", v.Content),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
case contentTypeImage:
|
||||
case ContentTypeImage:
|
||||
if imgData, ok := v.Content.(map[string]any); ok {
|
||||
callToolResult.Content = append(callToolResult.Content, ImageContent{
|
||||
Type: contentTypeImage,
|
||||
Data: fmt.Sprintf("%v", imgData["data"]),
|
||||
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
|
||||
})
|
||||
}
|
||||
default:
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("%v", v.Content),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
}
|
||||
default:
|
||||
// For any other type, convert to string
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Type: contentTypeText,
|
||||
Text: fmt.Sprintf("%v", v),
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
callToolResult.Content = toTypedContents(callToolResult.Content)
|
||||
logx.Infof("Tool call result: %#v", callToolResult)
|
||||
s.sendResponse(client, req.ID, callToolResult)
|
||||
|
||||
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||
}
|
||||
|
||||
// processResourcesRead processes the resources/read request
|
||||
func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClient, req Request) {
|
||||
var params ResourceReadParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -821,7 +778,7 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
s.resourcesLock.Unlock()
|
||||
|
||||
if !exists {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
params.URI), errCodeResourceNotFound)
|
||||
return
|
||||
}
|
||||
@@ -837,14 +794,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
},
|
||||
},
|
||||
}
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute the resource handler
|
||||
content, err := resource.Handler()
|
||||
content, err := resource.Handler(ctx)
|
||||
if err != nil {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||
errCodeInternalError)
|
||||
return
|
||||
}
|
||||
@@ -855,7 +812,7 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
}
|
||||
|
||||
// Ensure MimeType is set if available from the resource definition
|
||||
if len(content.MimeType) == 0 && resource.MimeType != "" {
|
||||
if len(content.MimeType) == 0 && len(resource.MimeType) > 0 {
|
||||
content.MimeType = resource.MimeType
|
||||
}
|
||||
|
||||
@@ -865,14 +822,14 @@ func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
|
||||
Contents: []ResourceContent{content},
|
||||
}
|
||||
|
||||
s.sendResponse(client, req.ID, result)
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processResourceSubscribe processes the resources/subscribe request
|
||||
func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processResourceSubscribe(ctx context.Context, client *mcpClient, req Request) {
|
||||
var params ResourceSubscribeParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -882,19 +839,17 @@ func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request)
|
||||
s.resourcesLock.Unlock()
|
||||
|
||||
if !exists {
|
||||
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
params.URI), errCodeResourceNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Send success response for the subscription
|
||||
s.sendResponse(client, req.ID, struct{}{})
|
||||
|
||||
logx.Infof("Client %s subscribed to resource '%s'", client.id, params.URI)
|
||||
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||
}
|
||||
|
||||
// processNotificationCancelled processes the notifications/cancelled notification
|
||||
func (s *sseMcpServer) processNotificationCancelled(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processNotificationCancelled(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract the requestId that was canceled
|
||||
type CancelParams struct {
|
||||
RequestId int64 `json:"requestId"`
|
||||
@@ -918,21 +873,20 @@ func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
|
||||
}
|
||||
|
||||
// processPing processes the ping request and responds immediately
|
||||
func (s *sseMcpServer) processPing(client *mcpClient, req Request) {
|
||||
func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req Request) {
|
||||
// A ping request should simply respond with an empty result to confirm the server is alive
|
||||
logx.Infof("Received ping request with ID: %d", req.ID)
|
||||
|
||||
// Send an empty response with client's original request ID
|
||||
s.sendResponse(client, req.ID, struct{}{})
|
||||
|
||||
logx.Infof("Sent ping response for ID: %d", req.ID)
|
||||
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||
}
|
||||
|
||||
// sendErrorResponse sends an error response via the SSE channel
|
||||
func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message string, code int) {
|
||||
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||
id any, message string, code int) {
|
||||
errorResponse := struct {
|
||||
JsonRpc string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
ID any `json:"id"`
|
||||
Error errorMessage `json:"error"`
|
||||
}{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
@@ -947,13 +901,19 @@ func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message st
|
||||
jsonData, _ := json.Marshal(errorResponse)
|
||||
// Use CRLF line endings as requested
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
|
||||
logx.Infof("Sending error for ID %v: %s", id, sseMessage)
|
||||
|
||||
client.channel <- sseMessage
|
||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||
select {
|
||||
case client.channel <- sseMessage:
|
||||
default:
|
||||
// Channel buffer is full, log warning and continue
|
||||
logx.Infof("Client %s channel is full while sending error response with ID %d", client.id, id)
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponse sends a success response via the SSE channel
|
||||
func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
||||
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) {
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: id,
|
||||
@@ -962,13 +922,19 @@ func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
|
||||
|
||||
jsonData, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
s.sendErrorResponse(client, id, "Failed to marshal response", errCodeInternalError)
|
||||
s.sendErrorResponse(ctx, client, id, "Failed to marshal response", errCodeInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
// Use CRLF line endings as requested
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
|
||||
logx.Infof("Sending response for ID %v: %s", id, sseMessage)
|
||||
|
||||
client.channel <- sseMessage
|
||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||
select {
|
||||
case client.channel <- sseMessage:
|
||||
default:
|
||||
// Channel buffer is full, log warning and continue
|
||||
logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ host: localhost
|
||||
port: 8080
|
||||
mcp:
|
||||
name: mcp-test-server
|
||||
toolTimeout: 5s
|
||||
messageTimeout: 5s
|
||||
`
|
||||
|
||||
var c McpConf
|
||||
@@ -82,7 +82,6 @@ func (m *mockMcpServer) registerExampleTool() {
|
||||
Name: "test.tool",
|
||||
Description: "A test tool",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
@@ -91,7 +90,7 @@ func (m *mockMcpServer) registerExampleTool() {
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
input, ok := params["input"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid input parameter")
|
||||
@@ -135,7 +134,7 @@ port: 8080
|
||||
mcp:
|
||||
cors:
|
||||
- http://localhost:3000
|
||||
toolTimeout: 5s
|
||||
messageTimeout: 5s
|
||||
`
|
||||
|
||||
var c McpConf
|
||||
@@ -176,6 +175,20 @@ func TestHandleRequest_badRequest(t *testing.T) {
|
||||
mock.server.handleRequest(w, r)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
|
||||
t.Run("bad id", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
addTestClient(mock.server, "test-session", true)
|
||||
|
||||
body := `{"jsonrpc": "2.0", "id": {}, "method": "tools.call", "params": {}}`
|
||||
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-session", bytes.NewReader([]byte(body)))
|
||||
w := httptest.NewRecorder()
|
||||
mock.server.handleRequest(w, r)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Invalid request.ID")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegisterTool(t *testing.T) {
|
||||
@@ -186,7 +199,6 @@ func TestRegisterTool(t *testing.T) {
|
||||
Name: "example.tool",
|
||||
Description: "An example tool",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
@@ -194,7 +206,7 @@ func TestRegisterTool(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return "result", nil
|
||||
},
|
||||
}
|
||||
@@ -280,7 +292,7 @@ func TestToolsList(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processListTools(client, req)
|
||||
mock.server.processListTools(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -328,7 +340,7 @@ func TestToolCallBasic(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -355,8 +367,7 @@ func TestToolCallBasic(t *testing.T) {
|
||||
|
||||
// Verify the response content
|
||||
assert.Len(t, parsed.Result.Content, 1, "Response should contain one content item")
|
||||
assert.Equal(t, "text", parsed.Result.Content[0]["type"], "Content type should be text")
|
||||
assert.Equal(t, "Processed: test-input", parsed.Result.Content[0]["text"], "Tool result incorrect")
|
||||
assert.Equal(t, "Processed: test-input", parsed.Result.Content[0][ContentTypeText], "Tool result incorrect")
|
||||
assert.False(t, parsed.Result.IsError, "Response should not be an error")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tool call response")
|
||||
@@ -373,10 +384,9 @@ func TestToolCallMapResult(t *testing.T) {
|
||||
Name: "map.tool",
|
||||
Description: "A tool that returns a map result",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return a complex nested map structure
|
||||
return map[string]any{
|
||||
"string": "value",
|
||||
@@ -417,7 +427,7 @@ func TestToolCallMapResult(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -445,13 +455,8 @@ func TestToolCallMapResult(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's a text content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
||||
|
||||
// Get the text content which should be our JSON
|
||||
text, ok := firstItem["text"].(string)
|
||||
text, ok := firstItem[ContentTypeText].(string)
|
||||
require.True(t, ok, "Content should have text")
|
||||
|
||||
// Verify the text is valid JSON and contains our data
|
||||
@@ -496,10 +501,9 @@ func TestToolCallArrayResult(t *testing.T) {
|
||||
Name: "array.tool",
|
||||
Description: "A tool that returns an array result",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return an array of mixed content types
|
||||
return []any{
|
||||
"string item",
|
||||
@@ -536,7 +540,7 @@ func TestToolCallArrayResult(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -574,16 +578,14 @@ func TestToolCallTextContentResult(t *testing.T) {
|
||||
Name: "text.content.tool",
|
||||
Description: "A tool that returns a TextContent result",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return a TextContent object directly
|
||||
return TextContent{
|
||||
Type: "text",
|
||||
Text: "This is a direct TextContent result",
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: func() *float64 { p := 0.9; return &p }(),
|
||||
},
|
||||
}, nil
|
||||
@@ -614,7 +616,7 @@ func TestToolCallTextContentResult(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -642,16 +644,6 @@ func TestToolCallTextContentResult(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's a text content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
||||
|
||||
// Check text content
|
||||
text, ok := firstItem["text"].(string)
|
||||
require.True(t, ok, "Content should have text")
|
||||
assert.Equal(t, "This is a direct TextContent result", text, "Text content should match")
|
||||
|
||||
// Check annotations
|
||||
annotations, ok := firstItem["annotations"].(map[string]any)
|
||||
require.True(t, ok, "Should have annotations")
|
||||
@@ -679,13 +671,11 @@ func TestToolCallImageContentResult(t *testing.T) {
|
||||
Name: "image.content.tool",
|
||||
Description: "A tool that returns an ImageContent result",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return an ImageContent object directly
|
||||
return ImageContent{
|
||||
Type: "image",
|
||||
Data: "dGVzdCBpbWFnZSBkYXRhIChiYXNlNjQgZW5jb2RlZCk=", // "test image data (base64 encoded)" in base64
|
||||
MimeType: "image/png",
|
||||
}, nil
|
||||
@@ -716,7 +706,7 @@ func TestToolCallImageContentResult(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -744,11 +734,6 @@ func TestToolCallImageContentResult(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's an image content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "image", contentType, "Content type should be image")
|
||||
|
||||
// Check image data
|
||||
data, ok := firstItem["data"].(string)
|
||||
require.True(t, ok, "Content should have data")
|
||||
@@ -773,12 +758,12 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.tool",
|
||||
Description: "A tool that returns a ToolResult object",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Type: ContentTypeObject,
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return ToolResult{
|
||||
Type: "text",
|
||||
Type: ContentTypeText,
|
||||
Content: "This is a ToolResult with text content type",
|
||||
}, nil
|
||||
},
|
||||
@@ -790,10 +775,10 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.image.tool",
|
||||
Description: "A tool that returns a ToolResult with image content",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Type: ContentTypeObject,
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return ToolResult{
|
||||
Type: "image",
|
||||
Content: map[string]any{
|
||||
@@ -810,10 +795,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.audio.tool",
|
||||
Description: "A tool that returns a ToolResult with audio content",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Test with image type
|
||||
return ToolResult{
|
||||
Type: "audio",
|
||||
@@ -831,10 +815,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.int.tool",
|
||||
Description: "A tool that returns a ToolResult with int content",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return 2, nil
|
||||
},
|
||||
}
|
||||
@@ -845,10 +828,9 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
Name: "toolresult.bad.tool",
|
||||
Description: "A tool that returns a ToolResult with bad content",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return map[string]any{
|
||||
"type": "custom",
|
||||
"data": make(chan int),
|
||||
@@ -881,7 +863,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -909,13 +891,8 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's a text content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be text")
|
||||
|
||||
// Check text content
|
||||
text, ok := firstItem["text"].(string)
|
||||
text, ok := firstItem[ContentTypeText].(string)
|
||||
require.True(t, ok, "Content should have text")
|
||||
assert.Equal(t, "This is a ToolResult with text content type", text, "Text content should match")
|
||||
|
||||
@@ -947,7 +924,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -975,11 +952,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's an image content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "image", contentType, "Content type should be image")
|
||||
|
||||
// Check image data and mime type
|
||||
data, ok := firstItem["data"].(string)
|
||||
require.True(t, ok, "Content should have data")
|
||||
@@ -1017,7 +989,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -1040,15 +1012,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok, "Result should have a content array")
|
||||
require.NotEmpty(t, content, "Content should not be empty")
|
||||
|
||||
// The first content item should be converted from ToolResult to ImageContent
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's an image content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be image")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tool call response")
|
||||
}
|
||||
@@ -1077,7 +1040,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -1100,15 +1063,6 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
content, ok := result["content"].([]any)
|
||||
require.True(t, ok, "Result should have a content array")
|
||||
require.NotEmpty(t, content, "Content should not be empty")
|
||||
|
||||
// The first content item should be converted from ToolResult to ImageContent
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
require.True(t, ok, "First content item should be an object")
|
||||
|
||||
// Verify it's an image content
|
||||
contentType, ok := firstItem["type"].(string)
|
||||
require.True(t, ok, "Content should have a type")
|
||||
assert.Equal(t, "text", contentType, "Content type should be image")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tool call response")
|
||||
}
|
||||
@@ -1137,7 +1091,7 @@ func TestToolCallToolResultType(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Get the response from the client's channel
|
||||
select {
|
||||
@@ -1159,10 +1113,9 @@ func TestToolCallError(t *testing.T) {
|
||||
Name: "error.tool",
|
||||
Description: "A tool that returns an error",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return nil, fmt.Errorf("tool execution failed")
|
||||
},
|
||||
})
|
||||
@@ -1189,7 +1142,7 @@ func TestToolCallError(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Check the response
|
||||
select {
|
||||
@@ -1207,20 +1160,16 @@ func TestToolCallTimeout(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Set a very short timeout for testing
|
||||
mock.server.conf.Mcp.ToolTimeout = 10 * time.Millisecond
|
||||
|
||||
// Register a tool that times out
|
||||
err := mock.server.RegisterTool(Tool{
|
||||
Name: "timeout.tool",
|
||||
Description: "A tool that times out",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
time.Sleep(50 * time.Millisecond) // Sleep longer than timeout
|
||||
return "this should never be returned", nil
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
<-ctx.Done()
|
||||
return nil, fmt.Errorf("tool execution timed out")
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -1244,16 +1193,24 @@ func TestToolCallTimeout(t *testing.T) {
|
||||
Method: methodToolsCall,
|
||||
Params: paramBytes,
|
||||
}
|
||||
jsonBody, _ := json.Marshal(req)
|
||||
|
||||
// Process the tool call
|
||||
mock.server.processToolCall(client, req)
|
||||
// Create HTTP request
|
||||
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client", bytes.NewReader(jsonBody))
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
r = r.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Process through handleRequest
|
||||
go mock.server.handleRequest(w, r)
|
||||
|
||||
// Check the response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, "event: message", "Response should have message event")
|
||||
assert.Contains(t, response, `-32001`, "Response should contain a timeout error code")
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tool call response")
|
||||
}
|
||||
}
|
||||
@@ -1274,7 +1231,7 @@ func TestInitializeAndNotifications(t *testing.T) {
|
||||
Params: json.RawMessage(`{}`),
|
||||
}
|
||||
|
||||
mock.server.processInitialize(client, initReq)
|
||||
mock.server.processInitialize(context.Background(), client, initReq)
|
||||
|
||||
// Check that client is initialized after initialize request
|
||||
assert.True(t, client.initialized, "Client should be marked as initialized after initialize request")
|
||||
@@ -1418,7 +1375,7 @@ func TestNotificationCancelled_badParams(t *testing.T) {
|
||||
}
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
mock.server.processNotificationCancelled(client, cancelReq)
|
||||
mock.server.processNotificationCancelled(context.Background(), client, cancelReq)
|
||||
|
||||
select {
|
||||
case <-client.channel:
|
||||
@@ -1593,7 +1550,7 @@ func TestGetPrompt(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the request
|
||||
mock.server.processGetPrompt(client, promptReq)
|
||||
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
@@ -1622,7 +1579,7 @@ func TestGetPrompt(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the request
|
||||
mock.server.processGetPrompt(client, promptReq)
|
||||
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
@@ -1636,6 +1593,44 @@ func TestGetPrompt(t *testing.T) {
|
||||
t.Fatal("Timed out waiting for prompt response")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test prompt with nil params", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
|
||||
// Register a test prompt
|
||||
testPrompt := Prompt{
|
||||
Name: "test.prompt",
|
||||
Description: "A test prompt",
|
||||
}
|
||||
mock.server.RegisterPrompt(testPrompt)
|
||||
|
||||
// Create a get prompt request
|
||||
paramBytes, _ := json.Marshal(map[string]any{
|
||||
"name": "test.prompt",
|
||||
})
|
||||
promptReq := Request{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: 1,
|
||||
Method: "prompts/get",
|
||||
Params: paramBytes,
|
||||
}
|
||||
|
||||
// Process the request
|
||||
mock.server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
_, err := parseEvent(response)
|
||||
assert.NoError(t, err, "Should be able to parse event")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompt response")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestBroadcast tests the broadcast functionality
|
||||
@@ -1903,34 +1898,79 @@ func TestNotificationInitialized(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendBadResponse(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
func TestSendResponse(t *testing.T) {
|
||||
t.Run("bad response", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
|
||||
// Create a response
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: 1,
|
||||
Result: make(chan int),
|
||||
}
|
||||
// Create a response
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: 1,
|
||||
Result: make(chan int),
|
||||
}
|
||||
|
||||
// Send the response
|
||||
mock.server.sendResponse(client, 1, response)
|
||||
// Send the response
|
||||
mock.server.sendResponse(context.Background(), client, 1, response)
|
||||
|
||||
// Check the response in the client's channel
|
||||
select {
|
||||
case res := <-client.channel:
|
||||
evt, err := parseEvent(res)
|
||||
require.NoError(t, err, "Should parse event without error")
|
||||
errMsg, ok := evt.Data["error"].(map[string]any)
|
||||
require.True(t, ok, "Should have error in response")
|
||||
assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for response")
|
||||
}
|
||||
// Check the response in the client's channel
|
||||
select {
|
||||
case res := <-client.channel:
|
||||
evt, err := parseEvent(res)
|
||||
require.NoError(t, err, "Should parse event without error")
|
||||
errMsg, ok := evt.Data["error"].(map[string]any)
|
||||
require.True(t, ok, "Should have error in response")
|
||||
assert.Equal(t, float64(errCodeInternalError), errMsg["code"], "Error code should match")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for response")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("channel full", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
for i := 0; i < eventChanSize; i++ {
|
||||
client.channel <- "test"
|
||||
}
|
||||
|
||||
// Create a response
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: 1,
|
||||
Result: "foo",
|
||||
}
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
// Send the response
|
||||
mock.server.sendResponse(context.Background(), client, 1, response)
|
||||
// Check the response in the client's channel
|
||||
assert.Contains(t, buf.String(), "channel is full")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendErrorResponse(t *testing.T) {
|
||||
t.Run("channel full", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Create a test client
|
||||
client := addTestClient(mock.server, "test-client", true)
|
||||
for i := 0; i < eventChanSize; i++ {
|
||||
client.channel <- "test"
|
||||
}
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
// Send the response
|
||||
mock.server.sendErrorResponse(context.Background(), client, 1, "foo", errCodeInternalError)
|
||||
// Check the response in the client's channel
|
||||
assert.Contains(t, buf.String(), "channel is full")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMethodToolsCall tests the handling of tools/call method through handleRequest
|
||||
@@ -2028,8 +2068,7 @@ func TestMethodToolsCall(t *testing.T) {
|
||||
if len(content) > 0 {
|
||||
firstItem, ok := content[0].(map[string]any)
|
||||
if ok {
|
||||
assert.Equal(t, "text", firstItem["type"], "Content type should be text")
|
||||
assert.Contains(t, firstItem["text"], "Processed: test-input", "Content should include processed input")
|
||||
assert.Contains(t, firstItem[ContentTypeText], "Processed: test-input", "Content should include processed input")
|
||||
}
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
@@ -2145,7 +2184,6 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
{
|
||||
Name: "topic",
|
||||
Description: "Topic to discuss",
|
||||
Default: "artificial intelligence",
|
||||
},
|
||||
},
|
||||
Content: "Hello {{name}}! Let's talk about {{topic}}.",
|
||||
@@ -2227,14 +2265,12 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
if len(messages) > 0 {
|
||||
message, ok := messages[0].(map[string]any)
|
||||
require.True(t, ok, "Message should be an object")
|
||||
assert.Equal(t, "user", message["role"], "Role should be 'user'")
|
||||
assert.Equal(t, string(RoleUser), message["role"], "Role should be 'user'")
|
||||
|
||||
content, ok := message["content"].(map[string]any)
|
||||
require.True(t, ok, "Should have content object")
|
||||
assert.Equal(t, "text", content["type"], "Content type should be text")
|
||||
assert.Contains(t, content["text"], "Hello Test User", "Content should include the name argument")
|
||||
assert.Contains(t, content["text"], "about artificial intelligence",
|
||||
"Content should include the default topic argument")
|
||||
assert.Equal(t, ContentTypeText, content["type"], "Content type should be text")
|
||||
assert.Contains(t, content[ContentTypeText], "Hello Test User", "Content should include the name argument")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompt get response")
|
||||
@@ -2255,27 +2291,24 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
{
|
||||
Name: "question",
|
||||
Description: "User's question",
|
||||
Default: "How does this work?",
|
||||
},
|
||||
},
|
||||
Handler: func(args map[string]string) ([]PromptMessage, error) {
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) {
|
||||
username := args["username"]
|
||||
question := args["question"]
|
||||
|
||||
// Create a system message
|
||||
systemMessage := PromptMessage{
|
||||
Role: "system",
|
||||
Role: RoleAssistant,
|
||||
Content: TextContent{
|
||||
Type: "text",
|
||||
Text: "You are a helpful assistant.",
|
||||
},
|
||||
}
|
||||
|
||||
// Create a user message
|
||||
userMessage := PromptMessage{
|
||||
Role: "user",
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("Hi, I'm %s and I'm wondering: %s", username, question),
|
||||
},
|
||||
}
|
||||
@@ -2340,20 +2373,20 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
|
||||
// Check message content
|
||||
if len(messages) >= 2 {
|
||||
// First message should be system
|
||||
// First message should be assistant
|
||||
message1, _ := messages[0].(map[string]any)
|
||||
assert.Equal(t, "system", message1["role"], "First role should be 'system'")
|
||||
assert.Equal(t, string(RoleAssistant), message1["role"], "First role should be 'system'")
|
||||
|
||||
content1, _ := message1["content"].(map[string]any)
|
||||
assert.Contains(t, content1["text"], "helpful assistant", "System message should be correct")
|
||||
assert.Contains(t, content1[ContentTypeText], "helpful assistant", "System message should be correct")
|
||||
|
||||
// Second message should be user
|
||||
message2, _ := messages[1].(map[string]any)
|
||||
assert.Equal(t, "user", message2["role"], "Second role should be 'user'")
|
||||
assert.Equal(t, string(RoleUser), message2["role"], "Second role should be 'user'")
|
||||
|
||||
content2, _ := message2["content"].(map[string]any)
|
||||
assert.Contains(t, content2["text"], "Dynamic User", "User message should contain username")
|
||||
assert.Contains(t, content2["text"], "How to test this?", "User message should contain question")
|
||||
assert.Contains(t, content2[ContentTypeText], "Dynamic User", "User message should contain username")
|
||||
assert.Contains(t, content2[ContentTypeText], "How to test this?", "User message should contain question")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompt get response")
|
||||
@@ -2459,7 +2492,7 @@ func TestMethodPromptsGet(t *testing.T) {
|
||||
Name: "error-handler-prompt",
|
||||
Description: "A prompt with a handler that returns an error",
|
||||
Arguments: []PromptArgument{},
|
||||
Handler: func(args map[string]string) ([]PromptMessage, error) {
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]PromptMessage, error) {
|
||||
return nil, fmt.Errorf("test handler error")
|
||||
},
|
||||
}
|
||||
@@ -2583,7 +2616,7 @@ func TestMethodResourcesList(t *testing.T) {
|
||||
URI: "file:///test/resource.txt",
|
||||
Description: "A test resource with handler",
|
||||
MimeType: "text/plain",
|
||||
Handler: func() (ResourceContent, error) {
|
||||
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||
return ResourceContent{
|
||||
URI: "file:///test/resource.txt",
|
||||
MimeType: "text/plain",
|
||||
@@ -2654,7 +2687,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
URI: "file:///test/resource.txt",
|
||||
Description: "A test resource with handler",
|
||||
MimeType: "text/plain",
|
||||
Handler: func() (ResourceContent, error) {
|
||||
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||
return ResourceContent{
|
||||
URI: "file:///test/resource.txt",
|
||||
MimeType: "text/plain",
|
||||
@@ -2729,7 +2762,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
require.True(t, ok, "Content should be an object")
|
||||
assert.Equal(t, "file:///test/resource.txt", content["uri"], "URI should match")
|
||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
||||
assert.Equal(t, "This is test resource content", content["text"], "Text content should match")
|
||||
assert.Equal(t, "This is test resource content", content[ContentTypeText], "Text content should match")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for resource read response")
|
||||
@@ -2799,7 +2832,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
require.True(t, ok, "Content should be an object")
|
||||
assert.Equal(t, "file:///test/no-handler.txt", content["uri"], "URI should match")
|
||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should match")
|
||||
_, ok = content["text"]
|
||||
_, ok = content[ContentTypeText]
|
||||
assert.False(t, ok, "Text content should be empty string")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
@@ -2880,7 +2913,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
client := addTestClient(mock.server, "test-client-resources", true)
|
||||
|
||||
// Process through handleRequest
|
||||
mock.server.processResourcesRead(client, req)
|
||||
mock.server.processResourcesRead(context.Background(), client, req)
|
||||
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
@@ -2898,7 +2931,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
URI: "file:///test/error.txt",
|
||||
Description: "A test resource with handler that returns error",
|
||||
MimeType: "text/plain",
|
||||
Handler: func() (ResourceContent, error) {
|
||||
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||
return ResourceContent{}, fmt.Errorf("test handler error")
|
||||
},
|
||||
}
|
||||
@@ -2946,7 +2979,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
URI: "file:///test/missing-fields.txt",
|
||||
Description: "A test resource with handler that returns content missing fields",
|
||||
MimeType: "text/plain",
|
||||
Handler: func() (ResourceContent, error) {
|
||||
Handler: func(ctx context.Context) (ResourceContent, error) {
|
||||
// Return ResourceContent without URI and MimeType
|
||||
return ResourceContent{
|
||||
Text: "Content with missing fields",
|
||||
@@ -3006,7 +3039,7 @@ func TestMethodResourcesRead(t *testing.T) {
|
||||
require.True(t, ok, "Content should be an object")
|
||||
assert.Equal(t, "file:///test/missing-fields.txt", content["uri"], "URI should be filled from request")
|
||||
assert.Equal(t, "text/plain", content["mimeType"], "MimeType should be filled from resource")
|
||||
assert.Equal(t, "Content with missing fields", content["text"], "Text content should match")
|
||||
assert.Equal(t, "Content with missing fields", content[ContentTypeText], "Text content should match")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for resource read response")
|
||||
@@ -3159,7 +3192,7 @@ func TestMethodResourcesSubscribe(t *testing.T) {
|
||||
}
|
||||
|
||||
client := addTestClient(mock.server, "test-client-sub-not-found", true)
|
||||
mock.server.processResourceSubscribe(client, req)
|
||||
mock.server.processResourceSubscribe(context.Background(), client, req)
|
||||
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
@@ -3268,7 +3301,7 @@ func TestToolCallUnmarshalError(t *testing.T) {
|
||||
}
|
||||
|
||||
// Process the tool call directly
|
||||
mock.server.processToolCall(client, req)
|
||||
mock.server.processToolCall(context.Background(), client, req)
|
||||
|
||||
// Check for error response about invalid JSON
|
||||
select {
|
||||
@@ -3316,7 +3349,7 @@ func TestToolCallWithInvalidParams(t *testing.T) {
|
||||
jsonBody, _ := json.Marshal(req)
|
||||
|
||||
// Create HTTP request
|
||||
r := httptest.NewRequest("POST", "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody))
|
||||
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-client-invalid-json", bytes.NewReader(jsonBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Process through handleRequest
|
||||
|
||||
73
mcp/types.go
73
mcp/types.go
@@ -1,7 +1,9 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
@@ -14,11 +16,28 @@ type Cursor string
|
||||
type Request struct {
|
||||
SessionId string `form:"session_id"` // Session identifier for client tracking
|
||||
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
|
||||
ID int64 `json:"id"` // Request identifier for matching responses
|
||||
ID any `json:"id"` // Request identifier for matching responses
|
||||
Method string `json:"method"` // Method name to invoke
|
||||
Params json.RawMessage `json:"params"` // Parameters for the method
|
||||
}
|
||||
|
||||
func (r Request) isNotification() (bool, error) {
|
||||
switch val := r.ID.(type) {
|
||||
case int:
|
||||
return val == 0, nil
|
||||
case int64:
|
||||
return val == 0, nil
|
||||
case float64:
|
||||
return val == 0.0, nil
|
||||
case string:
|
||||
return len(val) == 0, nil
|
||||
case nil:
|
||||
return true, nil
|
||||
default:
|
||||
return false, fmt.Errorf("invalid type %T", val)
|
||||
}
|
||||
}
|
||||
|
||||
type PaginatedParams struct {
|
||||
Cursor string `json:"cursor"`
|
||||
Meta struct {
|
||||
@@ -45,19 +64,18 @@ type ListToolsResult struct {
|
||||
|
||||
// Message Content Types
|
||||
|
||||
// roleType represents the sender or recipient of messages in a conversation
|
||||
type roleType string
|
||||
// RoleType represents the sender or recipient of messages in a conversation
|
||||
type RoleType string
|
||||
|
||||
// PromptArgument defines a single argument that can be passed to a prompt
|
||||
type PromptArgument struct {
|
||||
Name string `json:"name"` // Argument name
|
||||
Description string `json:"description,omitempty"` // Human-readable description
|
||||
Required bool `json:"required,omitempty"` // Whether this argument is required
|
||||
Default string `json:"default,omitempty"` // Default value if not provided
|
||||
}
|
||||
|
||||
// PromptHandler is a function that dynamically generates prompt content
|
||||
type PromptHandler func(args map[string]string) ([]PromptMessage, error)
|
||||
type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error)
|
||||
|
||||
// Prompt represents an MCP Prompt definition
|
||||
type Prompt struct {
|
||||
@@ -70,31 +88,43 @@ type Prompt struct {
|
||||
|
||||
// PromptMessage represents a message in a conversation
|
||||
type PromptMessage struct {
|
||||
Role roleType `json:"role"` // Message sender role
|
||||
Role RoleType `json:"role"` // Message sender role
|
||||
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
||||
}
|
||||
|
||||
// TextContent represents text content in a message
|
||||
type TextContent struct {
|
||||
Type string `json:"type"` // Always "text"
|
||||
Text string `json:"text"` // The text content
|
||||
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
||||
}
|
||||
|
||||
type typedTextContent struct {
|
||||
Type string `json:"type"`
|
||||
TextContent
|
||||
}
|
||||
|
||||
// ImageContent represents image data in a message
|
||||
type ImageContent struct {
|
||||
Type string `json:"type"` // Always "image"
|
||||
Data string `json:"data"` // Base64-encoded image data
|
||||
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
||||
}
|
||||
|
||||
type typedImageContent struct {
|
||||
Type string `json:"type"`
|
||||
ImageContent
|
||||
}
|
||||
|
||||
// AudioContent represents audio data in a message
|
||||
type AudioContent struct {
|
||||
Type string `json:"type"` // Always "audio"
|
||||
Data string `json:"data"` // Base64-encoded audio data
|
||||
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
||||
}
|
||||
|
||||
type typedAudioContent struct {
|
||||
Type string `json:"type"`
|
||||
AudioContent
|
||||
}
|
||||
|
||||
// FileContent represents file content
|
||||
type FileContent struct {
|
||||
URI string `json:"uri"` // URI identifying the file
|
||||
@@ -104,27 +134,20 @@ type FileContent struct {
|
||||
|
||||
// EmbeddedResource represents a resource embedded in a message
|
||||
type EmbeddedResource struct {
|
||||
Type string `json:"type"` // Always "resource"
|
||||
Resource struct {
|
||||
URI string `json:"uri"` // Resource URI
|
||||
MimeType string `json:"mimeType"` // MIME type of the resource
|
||||
Text string `json:"text,omitempty"` // Text content (if available)
|
||||
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
|
||||
} `json:"resource"` // The resource data
|
||||
Type string `json:"type"` // Always "resource"
|
||||
Resource ResourceContent `json:"resource"` // The resource data
|
||||
}
|
||||
|
||||
// Annotations provides additional metadata for content
|
||||
type Annotations struct {
|
||||
Audience []roleType `json:"audience,omitempty"` // Who should see this content
|
||||
Audience []RoleType `json:"audience,omitempty"` // Who should see this content
|
||||
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
||||
}
|
||||
|
||||
// Tool-related Types
|
||||
|
||||
// Tool Definition Types
|
||||
|
||||
// ToolHandler is a function that handles tool calls
|
||||
type ToolHandler func(params map[string]any) (any, error)
|
||||
type ToolHandler func(ctx context.Context, params map[string]any) (any, error)
|
||||
|
||||
// Tool represents a Model Context Protocol Tool definition
|
||||
type Tool struct {
|
||||
@@ -136,7 +159,7 @@ type Tool struct {
|
||||
|
||||
// InputSchema represents tool's input schema in JSON Schema format
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"` // Always "object" for tool inputs
|
||||
Type string `json:"type"`
|
||||
Properties map[string]any `json:"properties"` // Property definitions
|
||||
Required []string `json:"required,omitempty"` // List of required properties
|
||||
}
|
||||
@@ -144,8 +167,8 @@ type InputSchema struct {
|
||||
// CallToolResult represents a tool call result that conforms to the MCP schema
|
||||
type CallToolResult struct {
|
||||
Result
|
||||
Content []interface{} `json:"content"` // Content items (text, images, etc.)
|
||||
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
||||
Content []any `json:"content"` // Content items (text, images, etc.)
|
||||
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
||||
}
|
||||
|
||||
// Resource represents a Model Context Protocol Resource definition
|
||||
@@ -158,7 +181,7 @@ type Resource struct {
|
||||
}
|
||||
|
||||
// ResourceHandler is a function that handles resource read requests
|
||||
type ResourceHandler func() (ResourceContent, error)
|
||||
type ResourceHandler func(ctx context.Context) (ResourceContent, error)
|
||||
|
||||
// ResourceContent represents the content of a resource
|
||||
type ResourceContent struct {
|
||||
@@ -239,7 +262,7 @@ type errorObj struct {
|
||||
// Response represents a JSON-RPC response
|
||||
type Response struct {
|
||||
JsonRpc string `json:"jsonrpc"` // Always "2.0"
|
||||
ID int64 `json:"id"` // Same as request ID
|
||||
ID any `json:"id"` // Same as request ID
|
||||
Result any `json:"result"` // Result object (null if error)
|
||||
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -54,7 +56,7 @@ func TestRequestUnmarshaling(t *testing.T) {
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2.0", req.JsonRpc)
|
||||
assert.Equal(t, int64(789), req.ID)
|
||||
assert.Equal(t, float64(789), req.ID)
|
||||
assert.Equal(t, "test_method", req.Method)
|
||||
|
||||
// Check params unmarshaled correctly
|
||||
@@ -79,7 +81,7 @@ func TestToolStructs(t *testing.T) {
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
Handler: func(params map[string]any) (any, error) {
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return "result", nil
|
||||
},
|
||||
}
|
||||
@@ -145,44 +147,38 @@ func TestResourceStructs(t *testing.T) {
|
||||
func TestContentTypes(t *testing.T) {
|
||||
// Test TextContent
|
||||
textContent := TextContent{
|
||||
Type: "text",
|
||||
Text: "Sample text",
|
||||
Annotations: &Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: ptr(1.0),
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(textContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"type":"text"`)
|
||||
assert.Contains(t, string(data), `"text":"Sample text"`)
|
||||
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
||||
assert.Contains(t, string(data), `"priority":1`)
|
||||
|
||||
// Test ImageContent
|
||||
imageContent := ImageContent{
|
||||
Type: "image",
|
||||
Data: "base64data",
|
||||
MimeType: "image/png",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(imageContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"type":"image"`)
|
||||
assert.Contains(t, string(data), `"data":"base64data"`)
|
||||
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
||||
|
||||
// Test AudioContent
|
||||
audioContent := AudioContent{
|
||||
Type: "audio",
|
||||
Data: "base64audio",
|
||||
MimeType: "audio/mp3",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(audioContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"type":"audio"`)
|
||||
assert.Contains(t, string(data), `"data":"base64audio"`)
|
||||
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
||||
}
|
||||
@@ -197,7 +193,6 @@ func TestCallToolResult(t *testing.T) {
|
||||
},
|
||||
Content: []interface{}{
|
||||
TextContent{
|
||||
Type: "text",
|
||||
Text: "Sample result",
|
||||
},
|
||||
},
|
||||
@@ -207,6 +202,70 @@ func TestCallToolResult(t *testing.T) {
|
||||
data, err := json.Marshal(result)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
||||
assert.Contains(t, string(data), `"content":[{"type":"text","text":"Sample result"}]`)
|
||||
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
||||
assert.NotContains(t, string(data), `"isError":`)
|
||||
}
|
||||
|
||||
func TestRequest_isNotification(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id any
|
||||
want bool
|
||||
wantErr error
|
||||
}{
|
||||
// integer test cases
|
||||
{name: "int zero", id: 0, want: true, wantErr: nil},
|
||||
{name: "int non-zero", id: 1, want: false, wantErr: nil},
|
||||
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
|
||||
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
|
||||
|
||||
// floating point number test cases
|
||||
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
|
||||
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
|
||||
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
|
||||
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
|
||||
|
||||
// string test cases
|
||||
{name: "empty string", id: "", want: true, wantErr: nil},
|
||||
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
|
||||
{name: "space string", id: " ", want: false, wantErr: nil},
|
||||
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
|
||||
|
||||
// special cases
|
||||
{name: "nil", id: nil, want: true, wantErr: nil},
|
||||
|
||||
// logical type test cases
|
||||
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
|
||||
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
|
||||
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
|
||||
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
|
||||
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
|
||||
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
|
||||
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := Request{
|
||||
SessionId: "test-session",
|
||||
JsonRpc: "2.0",
|
||||
ID: tt.id,
|
||||
Method: "testMethod",
|
||||
Params: json.RawMessage(`{}`),
|
||||
}
|
||||
|
||||
got, err := req.isNotification()
|
||||
|
||||
if (err != nil) != (tt.wantErr != nil) {
|
||||
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
|
||||
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
104
mcp/util.go
104
mcp/util.go
@@ -1,15 +1,107 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
import "fmt"
|
||||
|
||||
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
||||
func formatSSEMessage(event string, data []byte) string {
|
||||
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
||||
}
|
||||
|
||||
// ptr is a helper function to get a pointer to a value
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
||||
func formatSSEMessage(event string, data []byte) string {
|
||||
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
||||
func toTypedContents(contents []any) []any {
|
||||
typedContents := make([]any, len(contents))
|
||||
|
||||
for i, content := range contents {
|
||||
switch v := content.(type) {
|
||||
case TextContent:
|
||||
typedContents[i] = typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: v,
|
||||
}
|
||||
case ImageContent:
|
||||
typedContents[i] = typedImageContent{
|
||||
Type: ContentTypeImage,
|
||||
ImageContent: v,
|
||||
}
|
||||
case AudioContent:
|
||||
typedContents[i] = typedAudioContent{
|
||||
Type: ContentTypeAudio,
|
||||
AudioContent: v,
|
||||
}
|
||||
default:
|
||||
typedContents[i] = typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: TextContent{
|
||||
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return typedContents
|
||||
}
|
||||
|
||||
func toTypedPromptMessages(messages []PromptMessage) []PromptMessage {
|
||||
typedMessages := make([]PromptMessage, len(messages))
|
||||
|
||||
for i, msg := range messages {
|
||||
switch v := msg.Content.(type) {
|
||||
case TextContent:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: v,
|
||||
},
|
||||
}
|
||||
case ImageContent:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedImageContent{
|
||||
Type: ContentTypeImage,
|
||||
ImageContent: v,
|
||||
},
|
||||
}
|
||||
case AudioContent:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedAudioContent{
|
||||
Type: ContentTypeAudio,
|
||||
AudioContent: v,
|
||||
},
|
||||
}
|
||||
default:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: TextContent{
|
||||
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return typedMessages
|
||||
}
|
||||
|
||||
// validatePromptArguments checks if all required arguments are provided
|
||||
// Returns a list of missing required arguments
|
||||
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||
var missingArgs []string
|
||||
|
||||
for _, arg := range prompt.Arguments {
|
||||
if arg.Required {
|
||||
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||
missingArgs = append(missingArgs, arg.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return missingArgs
|
||||
}
|
||||
|
||||
253
mcp/util_test.go
253
mcp/util_test.go
@@ -8,29 +8,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPtr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
v interface{}
|
||||
}{
|
||||
{"string", "test"},
|
||||
{"int", 42},
|
||||
{"bool", true},
|
||||
{"float", 3.14},
|
||||
{"struct", struct{ Name string }{"test"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ptr(tt.v)
|
||||
assert.NotNil(t, got, "ptr() should not return nil")
|
||||
assert.Equal(t, tt.v, *got, "dereferenced pointer should equal input value")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
Type string
|
||||
Data map[string]any
|
||||
@@ -61,3 +41,234 @@ func parseEvent(input string) (*Event, error) {
|
||||
|
||||
return &evt, nil
|
||||
}
|
||||
|
||||
// TestToTypedPromptMessages tests the toTypedPromptMessages function
|
||||
func TestToTypedPromptMessages(t *testing.T) {
|
||||
// Test with multiple message types in one test
|
||||
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||
// Create test data with different content types
|
||||
messages := []PromptMessage{
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Text: "Hello, this is a text message",
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: ptr(0.8),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: RoleAssistant,
|
||||
Content: ImageContent{
|
||||
Data: "base64ImageData",
|
||||
MimeType: "image/jpeg",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: AudioContent{
|
||||
Data: "base64AudioData",
|
||||
MimeType: "audio/mp3",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "system",
|
||||
Content: "This is a simple string that should be handled as unknown type",
|
||||
},
|
||||
}
|
||||
|
||||
// Call the function
|
||||
result := toTypedPromptMessages(messages)
|
||||
|
||||
// Validate results
|
||||
require.Len(t, result, 4, "Should return the same number of messages")
|
||||
|
||||
// Validate first message (TextContent)
|
||||
msg := result[0]
|
||||
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||
|
||||
// Type assertion using reflection since Content is an interface
|
||||
typed, ok := msg.Content.(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved")
|
||||
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||
assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||
|
||||
// Validate second message (ImageContent)
|
||||
msg = result[1]
|
||||
assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved")
|
||||
|
||||
// Type assertion for image content
|
||||
typedImg, ok := msg.Content.(typedImageContent)
|
||||
require.True(t, ok, "Should be typedImageContent")
|
||||
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||
assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate third message (AudioContent)
|
||||
msg = result[2]
|
||||
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||
|
||||
// Type assertion for audio content
|
||||
typedAudio, ok := msg.Content.(typedAudioContent)
|
||||
require.True(t, ok, "Should be typedAudioContent")
|
||||
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||
assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate fourth message (unknown type converted to TextContent)
|
||||
msg = result[3]
|
||||
assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved")
|
||||
|
||||
// Should be converted to a typedTextContent with error message
|
||||
typedUnknown, ok := msg.Content.(typedTextContent)
|
||||
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||
})
|
||||
|
||||
// Test empty input
|
||||
t.Run("EmptyInput", func(t *testing.T) {
|
||||
messages := []PromptMessage{}
|
||||
result := toTypedPromptMessages(messages)
|
||||
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||
})
|
||||
|
||||
// Test with nil annotations
|
||||
t.Run("NilAnnotations", func(t *testing.T) {
|
||||
messages := []PromptMessage{
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Text: "Text with nil annotations",
|
||||
Annotations: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := toTypedPromptMessages(messages)
|
||||
require.Len(t, result, 1, "Should return one message")
|
||||
|
||||
typed, ok := result[0].Content.(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||
})
|
||||
}
|
||||
|
||||
// TestToTypedContents tests the toTypedContents function
|
||||
func TestToTypedContents(t *testing.T) {
|
||||
// Test with multiple content types in one test
|
||||
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||
// Create test data with different content types
|
||||
contents := []any{
|
||||
TextContent{
|
||||
Text: "Hello, this is a text content",
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: ptr(0.7),
|
||||
},
|
||||
},
|
||||
ImageContent{
|
||||
Data: "base64ImageData",
|
||||
MimeType: "image/png",
|
||||
},
|
||||
AudioContent{
|
||||
Data: "base64AudioData",
|
||||
MimeType: "audio/wav",
|
||||
},
|
||||
"This is a simple string that should be handled as unknown type",
|
||||
}
|
||||
|
||||
// Call the function
|
||||
result := toTypedContents(contents)
|
||||
|
||||
// Validate results
|
||||
require.Len(t, result, 4, "Should return the same number of contents")
|
||||
|
||||
// Validate first content (TextContent)
|
||||
typed, ok := result[0].(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved")
|
||||
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||
assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||
|
||||
// Validate second content (ImageContent)
|
||||
typedImg, ok := result[1].(typedImageContent)
|
||||
require.True(t, ok, "Should be typedImageContent")
|
||||
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||
assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate third content (AudioContent)
|
||||
typedAudio, ok := result[2].(typedAudioContent)
|
||||
require.True(t, ok, "Should be typedAudioContent")
|
||||
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||
assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate fourth content (unknown type converted to TextContent)
|
||||
typedUnknown, ok := result[3].(typedTextContent)
|
||||
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||
})
|
||||
|
||||
// Test empty input
|
||||
t.Run("EmptyInput", func(t *testing.T) {
|
||||
contents := []any{}
|
||||
result := toTypedContents(contents)
|
||||
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||
})
|
||||
|
||||
// Test with nil annotations
|
||||
t.Run("NilAnnotations", func(t *testing.T) {
|
||||
contents := []any{
|
||||
TextContent{
|
||||
Text: "Text with nil annotations",
|
||||
Annotations: nil,
|
||||
},
|
||||
}
|
||||
|
||||
result := toTypedContents(contents)
|
||||
require.Len(t, result, 1, "Should return one content")
|
||||
|
||||
typed, ok := result[0].(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||
})
|
||||
|
||||
// Test with custom struct (should be handled as unknown type)
|
||||
t.Run("CustomStruct", func(t *testing.T) {
|
||||
type CustomContent struct {
|
||||
Data string
|
||||
}
|
||||
|
||||
contents := []any{
|
||||
CustomContent{
|
||||
Data: "custom data",
|
||||
},
|
||||
}
|
||||
|
||||
result := toTypedContents(contents)
|
||||
require.Len(t, result, 1, "Should return one content")
|
||||
|
||||
typed, ok := result[0].(typedTextContent)
|
||||
require.True(t, ok, "Custom struct should be converted to typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||
assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type")
|
||||
})
|
||||
}
|
||||
|
||||
28
mcp/vars.go
28
mcp/vars.go
@@ -13,6 +13,9 @@ const (
|
||||
|
||||
// Session identifier key used in request URLs
|
||||
sessionIdKey = "session_id"
|
||||
|
||||
// progressTokenKey is used to track progress of long-running tasks
|
||||
progressTokenKey = "progressToken"
|
||||
)
|
||||
|
||||
// Server-Sent Events (SSE) event types
|
||||
@@ -26,11 +29,20 @@ const (
|
||||
|
||||
// Content type identifiers
|
||||
const (
|
||||
// Text content type
|
||||
contentTypeText = "text"
|
||||
// ContentTypeObject is object content type
|
||||
ContentTypeObject = "object"
|
||||
|
||||
// Image content type
|
||||
contentTypeImage = "image"
|
||||
// ContentTypeText is text content type
|
||||
ContentTypeText = "text"
|
||||
|
||||
// ContentTypeImage is image content type
|
||||
ContentTypeImage = "image"
|
||||
|
||||
// ContentTypeAudio is audio content type
|
||||
ContentTypeAudio = "audio"
|
||||
|
||||
// ContentTypeResource is resource content type
|
||||
ContentTypeResource = "resource"
|
||||
)
|
||||
|
||||
// Collection keys for broadcast events
|
||||
@@ -72,11 +84,11 @@ const (
|
||||
|
||||
// User and assistant role definitions
|
||||
const (
|
||||
// The "user" role - the entity asking questions
|
||||
roleUser roleType = "user"
|
||||
// RoleUser is the "user" role - the entity asking questions
|
||||
RoleUser RoleType = "user"
|
||||
|
||||
// The "assistant" role - the entity providing responses
|
||||
roleAssistant roleType = "assistant"
|
||||
// RoleAssistant is the "assistant" role - the entity providing responses
|
||||
RoleAssistant RoleType = "assistant"
|
||||
)
|
||||
|
||||
// Method names as defined in the MCP specification
|
||||
|
||||
@@ -146,13 +146,9 @@ func TestCollectionKeys(t *testing.T) {
|
||||
|
||||
// TestRoleTypes checks that role types are used correctly
|
||||
func TestRoleTypes(t *testing.T) {
|
||||
// Verify role type constants
|
||||
assert.Equal(t, "user", string(roleUser), "User role should be 'user'")
|
||||
assert.Equal(t, "assistant", string(roleAssistant), "Assistant role should be 'assistant'")
|
||||
|
||||
// Test in annotations
|
||||
annotations := Annotations{
|
||||
Audience: []roleType{roleUser, roleAssistant},
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
}
|
||||
data, err := json.Marshal(annotations)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -302,6 +302,8 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
|
||||
>103. 爱芯元智半导体股份有限公司
|
||||
>104. 杭州升恒科技有限公司
|
||||
>105. 昆仑万维科技股份有限公司
|
||||
>106. 无锡盛算信息技术有限公司
|
||||
>107. 深圳市聚货通信息科技有限公司
|
||||
|
||||
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
||||
|
||||
|
||||
@@ -251,7 +251,3 @@ go-zero enlisted in the [CNCF Cloud Native Landscape](https://landscape.cncf.io/
|
||||
## Give a Star! ⭐
|
||||
|
||||
If you like this project or are using it to learn or start your own solution, give it a star to get updates on new releases. Your support matters!
|
||||
|
||||
## Buy me a coffee
|
||||
|
||||
<a href="https://www.buymeacoffee.com/kevwan" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>
|
||||
|
||||
@@ -228,6 +228,10 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
|
||||
return ng.shedder
|
||||
}
|
||||
|
||||
func (ng *engine) hasTimeout() bool {
|
||||
return ng.conf.Middlewares.Timeout && ng.timeout > 0
|
||||
}
|
||||
|
||||
// notFoundHandler returns a middleware that handles 404 not found requests.
|
||||
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -354,16 +358,17 @@ func (ng *engine) use(middleware Middleware) {
|
||||
|
||||
func (ng *engine) withTimeout() internal.StartOption {
|
||||
return func(svr *http.Server) {
|
||||
timeout := ng.timeout
|
||||
if timeout > 0 {
|
||||
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
||||
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
||||
// which triggers the circuit breaker.
|
||||
svr.ReadTimeout = 4 * timeout / 5
|
||||
// factor 1.1, to avoid servers don't have enough time to write responses.
|
||||
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
||||
svr.WriteTimeout = 11 * timeout / 10
|
||||
if !ng.hasTimeout() {
|
||||
return
|
||||
}
|
||||
|
||||
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
||||
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
||||
// which triggers the circuit breaker.
|
||||
svr.ReadTimeout = 4 * ng.timeout / 5
|
||||
// factor 1.1, to avoid servers don't have enough time to write responses.
|
||||
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
||||
svr.WriteTimeout = 11 * ng.timeout / 10
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -394,7 +394,12 @@ func TestEngine_withTimeout(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ng := newEngine(RestConf{Timeout: test.timeout})
|
||||
ng := newEngine(RestConf{
|
||||
Timeout: test.timeout,
|
||||
Middlewares: MiddlewaresConf{
|
||||
Timeout: true,
|
||||
},
|
||||
})
|
||||
svr := &http.Server{}
|
||||
ng.withTimeout()(svr)
|
||||
|
||||
@@ -406,6 +411,62 @@ func TestEngine_withTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_ReadWriteTimeout(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout int64
|
||||
middleware bool
|
||||
}{
|
||||
{
|
||||
name: "0/false",
|
||||
timeout: 0,
|
||||
middleware: false,
|
||||
},
|
||||
{
|
||||
name: "0/true",
|
||||
timeout: 0,
|
||||
middleware: true,
|
||||
},
|
||||
{
|
||||
name: "set/false",
|
||||
timeout: 1000,
|
||||
middleware: false,
|
||||
},
|
||||
{
|
||||
name: "both set",
|
||||
timeout: 1000,
|
||||
middleware: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ng := newEngine(RestConf{
|
||||
Timeout: test.timeout,
|
||||
Middlewares: MiddlewaresConf{
|
||||
Timeout: test.middleware,
|
||||
},
|
||||
})
|
||||
svr := &http.Server{}
|
||||
ng.withTimeout()(svr)
|
||||
|
||||
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
||||
assert.Equal(t, time.Duration(0), svr.IdleTimeout)
|
||||
|
||||
if test.timeout > 0 && test.middleware {
|
||||
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
||||
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
|
||||
} else {
|
||||
assert.Equal(t, time.Duration(0), svr.ReadTimeout)
|
||||
assert.Equal(t, time.Duration(0), svr.WriteTimeout)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_start(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
|
||||
@@ -106,8 +106,8 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
case <-ctx.Done():
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
// there isn't any user-defined middleware before TimoutHandler,
|
||||
// so we can guarantee that cancelation in biz related code won't come here.
|
||||
// there isn't any user-defined middleware before TimeoutHandler,
|
||||
// so we can guarantee that cancellation in biz related code won't come here.
|
||||
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
w.WriteHeader(statusClientClosedRequest)
|
||||
@@ -151,7 +151,7 @@ func (tw *timeoutWriter) Flush() {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// Header returns the underline temporary http.Header.
|
||||
// Header returns the underlying temporary http.Header.
|
||||
func (tw *timeoutWriter) Header() http.Header {
|
||||
return tw.h
|
||||
}
|
||||
|
||||
@@ -2,15 +2,16 @@ package fileserver
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Middleware returns a middleware that serves files from the given file system.
|
||||
func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
|
||||
func Middleware(upath string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
|
||||
fileServer := http.FileServer(fs)
|
||||
pathWithoutTrailSlash := ensureNoTrailingSlash(path)
|
||||
canServe := createServeChecker(path, fs)
|
||||
pathWithoutTrailSlash := ensureNoTrailingSlash(upath)
|
||||
canServe := createServeChecker(upath, fs)
|
||||
|
||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -28,9 +29,22 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
||||
var lock sync.RWMutex
|
||||
fileChecker := make(map[string]bool)
|
||||
|
||||
return func(path string) bool {
|
||||
return func(upath string) bool {
|
||||
// Emulate http.Dir.Open’s path normalization for embed.FS.Open.
|
||||
// http.FileServer redirects any request ending in "/index.html"
|
||||
// to the same path without the final "index.html".
|
||||
// So the path here may be empty or end with a "/".
|
||||
// http.Dir.Open uses this logic to clean the path,
|
||||
// correctly handling those two cases.
|
||||
// embed.FS doesn’t perform this normalization, so we apply the same logic here.
|
||||
upath = path.Clean("/" + upath)[1:]
|
||||
if len(upath) == 0 {
|
||||
// if the path is empty, we use "." to open the current directory
|
||||
upath = "."
|
||||
}
|
||||
|
||||
lock.RLock()
|
||||
exist, ok := fileChecker[path]
|
||||
exist, ok := fileChecker[upath]
|
||||
lock.RUnlock()
|
||||
if ok {
|
||||
return exist
|
||||
@@ -39,9 +53,9 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
file, err := fs.Open(path)
|
||||
file, err := fs.Open(upath)
|
||||
exist = err == nil
|
||||
fileChecker[path] = exist
|
||||
fileChecker[upath] = exist
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -51,8 +65,8 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) bool {
|
||||
pathWithTrailSlash := ensureTrailingSlash(path)
|
||||
func createServeChecker(upath string, fs http.FileSystem) func(r *http.Request) bool {
|
||||
pathWithTrailSlash := ensureTrailingSlash(upath)
|
||||
fileChecker := createFileChecker(fs)
|
||||
|
||||
return func(r *http.Request) bool {
|
||||
@@ -62,18 +76,18 @@ func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) b
|
||||
}
|
||||
}
|
||||
|
||||
func ensureTrailingSlash(path string) string {
|
||||
if strings.HasSuffix(path, "/") {
|
||||
return path
|
||||
func ensureTrailingSlash(upath string) string {
|
||||
if strings.HasSuffix(upath, "/") {
|
||||
return upath
|
||||
}
|
||||
|
||||
return path + "/"
|
||||
return upath + "/"
|
||||
}
|
||||
|
||||
func ensureNoTrailingSlash(path string) string {
|
||||
if strings.HasSuffix(path, "/") {
|
||||
return path[:len(path)-1]
|
||||
func ensureNoTrailingSlash(upath string) string {
|
||||
if strings.HasSuffix(upath, "/") {
|
||||
return upath[:len(upath)-1]
|
||||
}
|
||||
|
||||
return path
|
||||
return upath
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package fileserver
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -61,6 +63,46 @@ func TestMiddleware(t *testing.T) {
|
||||
requestPath: "/ws",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
|
||||
// http.FileServer redirects any request ending in "/index.html"
|
||||
// to the same path, without the final "index.html".
|
||||
{
|
||||
name: "Serve index.html",
|
||||
path: "/static",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Serve index.html with path with trailing slash",
|
||||
path: "/static/",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Serve index.html in a nested directory",
|
||||
path: "/static",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/nested/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Request index.html indirectly",
|
||||
path: "/static",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "hello",
|
||||
},
|
||||
{
|
||||
name: "Request index.html in a nested directory indirectly",
|
||||
path: "/static",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/nested/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "hello",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -87,6 +129,128 @@ func TestMiddleware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
//go:embed testdata
|
||||
testdataFS embed.FS
|
||||
)
|
||||
|
||||
func TestMiddleware_embedFS(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
requestPath string
|
||||
expectedStatus int
|
||||
expectedContent string
|
||||
}{
|
||||
{
|
||||
name: "Serve static file",
|
||||
path: "/static",
|
||||
requestPath: "/static/example.txt",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "1",
|
||||
},
|
||||
{
|
||||
name: "Path with trailing slash",
|
||||
path: "/static/",
|
||||
requestPath: "/static/example.txt",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "1",
|
||||
},
|
||||
{
|
||||
name: "Root path",
|
||||
path: "/",
|
||||
requestPath: "/example.txt",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "1",
|
||||
},
|
||||
{
|
||||
name: "Pass through non-matching path",
|
||||
path: "/static/",
|
||||
requestPath: "/other/path",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
{
|
||||
name: "Not exist file",
|
||||
path: "/assets",
|
||||
requestPath: "/assets/not-exist.txt",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
{
|
||||
name: "Not exist file in root",
|
||||
path: "/",
|
||||
requestPath: "/not-exist.txt",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
{
|
||||
name: "websocket request",
|
||||
path: "/",
|
||||
requestPath: "/ws",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
|
||||
// http.FileServer redirects any request ending in "/index.html"
|
||||
// to the same path, without the final "index.html".
|
||||
{
|
||||
name: "Serve index.html",
|
||||
path: "/static",
|
||||
requestPath: "/static/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Serve index.html with path with trailing slash",
|
||||
path: "/static/",
|
||||
requestPath: "/static/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Serve index.html in a nested directory",
|
||||
path: "/static",
|
||||
requestPath: "/static/nested/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Request index.html indirectly",
|
||||
path: "/static",
|
||||
requestPath: "/static/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "hello",
|
||||
},
|
||||
{
|
||||
name: "Request index.html in a nested directory indirectly",
|
||||
path: "/static",
|
||||
requestPath: "/static/nested/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "hello",
|
||||
},
|
||||
}
|
||||
|
||||
subFS, err := fs.Sub(testdataFS, "testdata")
|
||||
assert.Nil(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
middleware := Middleware(tt.path, http.FS(subFS))
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusAlreadyReported)
|
||||
})
|
||||
|
||||
handlerToTest := middleware(nextHandler)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handlerToTest.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, rr.Code)
|
||||
if len(tt.expectedContent) > 0 {
|
||||
assert.Equal(t, tt.expectedContent, rr.Body.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureTrailingSlash(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
|
||||
1
rest/internal/fileserver/testdata/index.html
vendored
Normal file
1
rest/internal/fileserver/testdata/index.html
vendored
Normal file
@@ -0,0 +1 @@
|
||||
hello
|
||||
1
rest/internal/fileserver/testdata/nested/index.html
vendored
Normal file
1
rest/internal/fileserver/testdata/nested/index.html
vendored
Normal file
@@ -0,0 +1 @@
|
||||
hello
|
||||
1
tools/goctl/.gitignore
vendored
Normal file
1
tools/goctl/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
dist
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/internal/cobrax"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/pkg/env"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/plugin"
|
||||
)
|
||||
|
||||
@@ -78,6 +77,7 @@ func init() {
|
||||
goCmdFlags.StringVar(&gogen.VarStringRemote, "remote")
|
||||
goCmdFlags.StringVar(&gogen.VarStringBranch, "branch")
|
||||
goCmdFlags.BoolVar(&gogen.VarBoolWithTest, "test")
|
||||
goCmdFlags.BoolVar(&gogen.VarBoolTypeGroup, "type-group")
|
||||
goCmdFlags.StringVarWithDefaultValue(&gogen.VarStringStyle, "style", config.DefaultFormat)
|
||||
|
||||
javaCmdFlags.StringVar(&javagen.VarStringDir, "dir")
|
||||
@@ -110,8 +110,5 @@ func init() {
|
||||
validateCmdFlags.StringVar(&validate.VarStringAPI, "api")
|
||||
|
||||
// Add sub-commands
|
||||
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd)
|
||||
if env.UseExperimental() {
|
||||
Cmd.AddCommand(swaggerCmd)
|
||||
}
|
||||
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd, swaggerCmd)
|
||||
}
|
||||
|
||||
@@ -42,8 +42,19 @@ var (
|
||||
func GoFormatApi(_ *cobra.Command, _ []string) error {
|
||||
var be errorx.BatchError
|
||||
if VarBoolUseStdin {
|
||||
if err := apiFormatReader(os.Stdin, VarStringDir, VarBoolSkipCheckDeclare); err != nil {
|
||||
be.Add(err)
|
||||
if env.UseExperimental() {
|
||||
data, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
be.Add(err)
|
||||
} else {
|
||||
if err := apiF.Source(data, os.Stdout); err != nil {
|
||||
be.Add(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := apiFormatReader(os.Stdin, VarStringDir, VarBoolSkipCheckDeclare); err != nil {
|
||||
be.Add(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if len(VarStringDir) == 0 {
|
||||
|
||||
@@ -40,6 +40,8 @@ var (
|
||||
// VarStringStyle describes the style of output files.
|
||||
VarStringStyle string
|
||||
VarBoolWithTest bool
|
||||
// VarBoolTypeGroup describes whether to group types.
|
||||
VarBoolTypeGroup bool
|
||||
)
|
||||
|
||||
// GoCommand gen go project files from command line
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/pkg/env"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
||||
)
|
||||
@@ -41,53 +41,116 @@ func BuildTypes(types []spec.Type) (string, error) {
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func removeTypeFromDefault(tp spec.Type, group string, groupTypes map[string]map[string]spec.Type) map[string]map[string]spec.Type {
|
||||
func getTypeName(tp spec.Type) string {
|
||||
if tp == nil {
|
||||
return ""
|
||||
}
|
||||
switch val := tp.(type) {
|
||||
case spec.DefineStruct:
|
||||
typeName := util.Title(tp.Name())
|
||||
defaultGroups, ok := groupTypes[groupTypeDefault]
|
||||
if ok {
|
||||
delete(defaultGroups, typeName)
|
||||
types, ok := groupTypes[group]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[typeName] = tp
|
||||
groupTypes[group] = types
|
||||
}
|
||||
groupTypes[groupTypeDefault] = defaultGroups
|
||||
return typeName
|
||||
case spec.PointerType:
|
||||
groupTypes = removeTypeFromDefault(val.Type, group, groupTypes)
|
||||
return getTypeName(val.Type)
|
||||
case spec.ArrayType:
|
||||
groupTypes = removeTypeFromDefault(val.Value, group, groupTypes)
|
||||
return getTypeName(val.Value)
|
||||
}
|
||||
return groupTypes
|
||||
return ""
|
||||
}
|
||||
|
||||
func genTypesWithGroup(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||
groupTypes := make(map[string]map[string]spec.Type)
|
||||
for _, v := range api.Types {
|
||||
types, ok := groupTypes[groupTypeDefault]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[util.Title(v.Name())] = v
|
||||
groupTypes[groupTypeDefault] = types
|
||||
}
|
||||
typesBelongToFiles := make(map[string]*collection.Set)
|
||||
|
||||
for _, v := range api.Service.Groups {
|
||||
group := v.GetAnnotation(groupProperty)
|
||||
if len(group) == 0 {
|
||||
group = groupTypeDefault
|
||||
}
|
||||
// convert filepath to Identifier name spec.
|
||||
group = strings.TrimPrefix(group, "/")
|
||||
group = strings.TrimSuffix(group, "/")
|
||||
group = util.SafeString(group)
|
||||
for _, v := range v.Routes {
|
||||
requestTypeName := getTypeName(v.RequestType)
|
||||
responseTypeName := getTypeName(v.ResponseType)
|
||||
requestTypeFileSet, ok := typesBelongToFiles[requestTypeName]
|
||||
if !ok {
|
||||
requestTypeFileSet = collection.NewSet()
|
||||
}
|
||||
if len(requestTypeName) > 0 {
|
||||
requestTypeFileSet.AddStr(group)
|
||||
typesBelongToFiles[requestTypeName] = requestTypeFileSet
|
||||
}
|
||||
|
||||
responseTypeFileSet, ok := typesBelongToFiles[responseTypeName]
|
||||
if !ok {
|
||||
responseTypeFileSet = collection.NewSet()
|
||||
}
|
||||
if len(responseTypeName) > 0 {
|
||||
responseTypeFileSet.AddStr(group)
|
||||
typesBelongToFiles[responseTypeName] = responseTypeFileSet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typesInOneFile := make(map[string]*collection.Set)
|
||||
for typeName, fileSet := range typesBelongToFiles {
|
||||
count := fileSet.Count()
|
||||
switch {
|
||||
case count == 0: // it means there has no structure type or no request/response body
|
||||
continue
|
||||
case count == 1: // it means a structure type used in only one group.
|
||||
groupName := fileSet.KeysStr()[0]
|
||||
typeSet, ok := typesInOneFile[groupName]
|
||||
if !ok {
|
||||
typeSet = collection.NewSet()
|
||||
}
|
||||
typeSet.AddStr(typeName)
|
||||
typesInOneFile[groupName] = typeSet
|
||||
default: // it means this type is used in multiple groups.
|
||||
continue
|
||||
}
|
||||
for _, v := range v.Routes {
|
||||
if v.RequestType != nil {
|
||||
groupTypes = removeTypeFromDefault(v.RequestType, group, groupTypes)
|
||||
}
|
||||
if v.ResponseType != nil {
|
||||
groupTypes = removeTypeFromDefault(v.ResponseType, group, groupTypes)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range api.Types {
|
||||
typeName := util.Title(v.Name())
|
||||
groupSet, ok := typesBelongToFiles[typeName]
|
||||
var typeCount int
|
||||
if !ok {
|
||||
typeCount = 0
|
||||
} else {
|
||||
typeCount = groupSet.Count()
|
||||
}
|
||||
|
||||
if typeCount == 0 { // not belong to any group
|
||||
types, ok := groupTypes[groupTypeDefault]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[typeName] = v
|
||||
groupTypes[groupTypeDefault] = types
|
||||
continue
|
||||
}
|
||||
|
||||
if typeCount == 1 { // belong to one group
|
||||
groupName := groupSet.KeysStr()[0]
|
||||
types, ok := groupTypes[groupName]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[typeName] = v
|
||||
groupTypes[groupName] = types
|
||||
continue
|
||||
}
|
||||
|
||||
// belong to multiple groups
|
||||
types, ok := groupTypes[groupTypeDefault]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[typeName] = v
|
||||
groupTypes[groupTypeDefault] = types
|
||||
|
||||
}
|
||||
|
||||
for group, typeGroup := range groupTypes {
|
||||
@@ -142,7 +205,7 @@ func writeTypes(dir, baseFilename string, cfg *config.Config, types []spec.Type)
|
||||
}
|
||||
|
||||
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||
if env.UseExperimental() {
|
||||
if VarBoolTypeGroup {
|
||||
return genTypesWithGroup(dir, cfg, api)
|
||||
}
|
||||
return writeTypes(dir, typesFile, cfg, api.Types)
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||
@@ -96,13 +96,13 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
|
||||
for _, item := range c.responseTypes {
|
||||
if item.Name() == defineStruct.Name() {
|
||||
superClassName = "HttpResponseData"
|
||||
if !stringx.Contains(c.imports, httpResponseData) {
|
||||
if !slices.Contains(c.imports, httpResponseData) {
|
||||
c.imports = append(c.imports, httpResponseData)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if superClassName == "HttpData" && !stringx.Contains(c.imports, httpData) {
|
||||
if superClassName == "HttpData" && !slices.Contains(c.imports, httpData) {
|
||||
c.imports = append(c.imports, httpData)
|
||||
}
|
||||
|
||||
@@ -266,7 +266,7 @@ func (c *componentsContext) genGetSet(writer io.Writer, indent int) error {
|
||||
tyString := javaType
|
||||
decorator := ""
|
||||
javaPrimitiveType := []string{"int", "long", "boolean", "float", "double", "short"}
|
||||
if !stringx.Contains(javaPrimitiveType, javaType) {
|
||||
if !slices.Contains(javaPrimitiveType, javaType) {
|
||||
if member.IsOptional() || member.IsOmitEmpty() {
|
||||
decorator = "@Nullable "
|
||||
} else {
|
||||
|
||||
@@ -3,9 +3,9 @@ package spec
|
||||
import (
|
||||
"errors"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||
)
|
||||
|
||||
@@ -64,7 +64,7 @@ func (m Member) IsOptional() bool {
|
||||
tag := m.Tags()
|
||||
for _, item := range tag {
|
||||
if item.Key == bodyTagKey || item.Key == formTagKey {
|
||||
if stringx.Contains(item.Options, "optional") {
|
||||
if slices.Contains(item.Options, "optional") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func (m Member) IsOmitEmpty() bool {
|
||||
tag := m.Tags()
|
||||
for _, item := range tag {
|
||||
if item.Key == bodyTagKey {
|
||||
if stringx.Contains(item.Options, "omitempty") {
|
||||
if slices.Contains(item.Options, "omitempty") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -93,7 +93,7 @@ func (m Member) IsOmitEmpty() bool {
|
||||
func (m Member) GetPropertyName() (string, error) {
|
||||
tags := m.Tags()
|
||||
for _, tag := range tags {
|
||||
if stringx.Contains(definedKeys, tag.Key) {
|
||||
if slices.Contains(definedKeys, tag.Key) {
|
||||
if tag.Name == "-" {
|
||||
return util.Untitle(m.Name), nil
|
||||
}
|
||||
|
||||
@@ -7,15 +7,6 @@ import (
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func hasKey(properties map[string]string, key string) bool {
|
||||
if len(properties) == 0 {
|
||||
return false
|
||||
}
|
||||
md := metadata.New(properties)
|
||||
_, ok := md[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func getBoolFromKVOrDefault(properties map[string]string, key string, def bool) bool {
|
||||
if len(properties) == 0 {
|
||||
return def
|
||||
@@ -69,3 +60,16 @@ func getListFromInfoOrDefault(properties map[string]string, key string, def []st
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func getFirstUsableString(def ...string) string {
|
||||
if len(def) == 0 {
|
||||
return ""
|
||||
}
|
||||
for _, val := range def {
|
||||
str := util.Unquote(val)
|
||||
if len(str) != 0 {
|
||||
return str
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func Test_getListFromInfoOrDefault(t *testing.T) {
|
||||
"empty": `""`,
|
||||
}
|
||||
|
||||
assert.Equal(t, []string{"a", "b", "c"}, getListFromInfoOrDefault(properties, "list", []string{"default"}))
|
||||
assert.Equal(t, []string{"a", " b", " c"}, getListFromInfoOrDefault(properties, "list", []string{"default"}))
|
||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "empty", []string{"default"}))
|
||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "missing", []string{"default"}))
|
||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(nil, "nil", []string{"default"}))
|
||||
|
||||
@@ -8,10 +8,9 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/parser"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package swagger
|
||||
|
||||
const (
|
||||
tagHeader = "header"
|
||||
tagPath = "path"
|
||||
tagForm = "form"
|
||||
tagJson = "json"
|
||||
defFlag = "default="
|
||||
enumFlag = "options="
|
||||
rangeFlag = "range="
|
||||
exampleFlag = "example="
|
||||
tagHeader = "header"
|
||||
tagPath = "path"
|
||||
tagForm = "form"
|
||||
tagJson = "json"
|
||||
defFlag = "default="
|
||||
enumFlag = "options="
|
||||
rangeFlag = "range="
|
||||
exampleFlag = "example="
|
||||
optionalFlag = "optional"
|
||||
|
||||
paramsInHeader = "header"
|
||||
paramsInPath = "path"
|
||||
@@ -27,6 +28,38 @@ const (
|
||||
applicationJson = "application/json"
|
||||
applicationForm = "application/x-www-form-urlencoded"
|
||||
schemeHttps = "https"
|
||||
defaultHost = "127.0.0.1"
|
||||
defaultBasePath = "/"
|
||||
)
|
||||
|
||||
const (
|
||||
propertyKeyUseDefinitions = "useDefinitions"
|
||||
propertyKeyExternalDocsDescription = "externalDocsDescription"
|
||||
propertyKeyExternalDocsURL = "externalDocsURL"
|
||||
propertyKeyTitle = "title"
|
||||
propertyKeyTermsOfService = "termsOfService"
|
||||
propertyKeyDescription = "description"
|
||||
propertyKeyVersion = "version"
|
||||
propertyKeyContactName = "contactName"
|
||||
propertyKeyContactURL = "contactURL"
|
||||
propertyKeyContactEmail = "contactEmail"
|
||||
propertyKeyLicenseName = "licenseName"
|
||||
propertyKeyLicenseURL = "licenseURL"
|
||||
propertyKeyProduces = "produces"
|
||||
propertyKeyConsumes = "consumes"
|
||||
propertyKeySchemes = "schemes"
|
||||
propertyKeyTags = "tags"
|
||||
propertyKeySummary = "summary"
|
||||
propertyKeyGroup = "group"
|
||||
propertyKeyOperationId = "operationId"
|
||||
propertyKeyDeprecated = "deprecated"
|
||||
propertyKeyPrefix = "prefix"
|
||||
propertyKeyAuthType = "authType"
|
||||
propertyKeyHost = "host"
|
||||
propertyKeyBasePath = "basePath"
|
||||
propertyKeyWrapCodeMsg = "wrapCodeMsg"
|
||||
propertyKeyBizCodeEnumDescription = "bizCodeEnumDescription"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultValueOfPropertyUseDefinition = false
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func consumesFromTypeOrDef(method string, tp spec.Type) []string {
|
||||
func consumesFromTypeOrDef(ctx Context, method string, tp spec.Type) []string {
|
||||
if strings.EqualFold(method, http.MethodGet) {
|
||||
return []string{}
|
||||
}
|
||||
@@ -18,7 +18,7 @@ func consumesFromTypeOrDef(method string, tp spec.Type) []string {
|
||||
if !ok {
|
||||
return []string{}
|
||||
}
|
||||
if typeContainsTag(structType, tagJson) {
|
||||
if typeContainsTag(ctx, structType, tagJson) {
|
||||
return []string{applicationJson}
|
||||
}
|
||||
return []string{applicationForm}
|
||||
|
||||
@@ -61,7 +61,7 @@ func TestConsumesFromTypeOrDef(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := consumesFromTypeOrDef(tt.method, tt.tp)
|
||||
result := consumesFromTypeOrDef(testingContext(t), tt.method, tt.tp)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
|
||||
28
tools/goctl/api/swagger/context.go
Normal file
28
tools/goctl/api/swagger/context.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
UseDefinitions bool
|
||||
WrapCodeMsg bool
|
||||
BizCodeEnumDescription string
|
||||
}
|
||||
|
||||
func testingContext(_ *testing.T) Context {
|
||||
return Context{}
|
||||
}
|
||||
|
||||
func contextFromApi(info spec.Info) Context {
|
||||
if len(info.Properties) == 0 {
|
||||
return Context{}
|
||||
}
|
||||
return Context{
|
||||
UseDefinitions: getBoolFromKVOrDefault(info.Properties, propertyKeyUseDefinitions, defaultValueOfPropertyUseDefinition),
|
||||
WrapCodeMsg: getBoolFromKVOrDefault(info.Properties, propertyKeyWrapCodeMsg, false),
|
||||
BizCodeEnumDescription: getStringFromKVOrDefault(info.Properties, propertyKeyBizCodeEnumDescription, "business code"),
|
||||
}
|
||||
}
|
||||
32
tools/goctl/api/swagger/definition.go
Normal file
32
tools/goctl/api/swagger/definition.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"github.com/go-openapi/spec"
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func definitionsFromTypes(ctx Context, types []apiSpec.Type) spec.Definitions {
|
||||
if !ctx.UseDefinitions {
|
||||
return nil
|
||||
}
|
||||
definitions := make(spec.Definitions)
|
||||
for _, tp := range types {
|
||||
typeName := tp.Name()
|
||||
definitions[typeName] = schemaFromType(ctx, tp)
|
||||
}
|
||||
return definitions
|
||||
}
|
||||
|
||||
func schemaFromType(ctx Context, tp apiSpec.Type) spec.Schema {
|
||||
p, r := propertiesFromType(ctx, tp)
|
||||
props := spec.SchemaProps{
|
||||
Type: typeFromGoType(ctx, tp),
|
||||
Properties: p,
|
||||
AdditionalProperties: mapFromGoType(ctx, tp),
|
||||
Items: itemsFromGoType(ctx, tp),
|
||||
Required: r,
|
||||
}
|
||||
return spec.Schema{
|
||||
SchemaProps: props,
|
||||
}
|
||||
}
|
||||
@@ -12,15 +12,16 @@ info (
|
||||
licenseURL: "https://github.com/zeromicro/go-zero" // licenseURL corresponding to Swagger
|
||||
consumes: "application/json" // consumes corresponding to Swagger,default value is `application/json`
|
||||
produces: "application/json" // produces corresponding to Swagger,default value is `application/json`
|
||||
schemes: "https" // schemes corresponding to Swagger,default value is `https``
|
||||
schemes: "http,https" // schemes corresponding to Swagger,default value is `https``
|
||||
host: "example.com" // host corresponding to Swagger,default value is `127.0.0.1`
|
||||
basePath: "/v1" // basePath corresponding to Swagger,default value is `/`
|
||||
wrapCodeMsg: "true" // to wrap in the universal code-msg structure, like {"code":0,"msg":"OK","data":$data}
|
||||
wrapCodeMsg: true // to wrap in the universal code-msg structure, like {"code":0,"msg":"OK","data":$data}
|
||||
bizCodeEnumDescription: "1001-User not login<br>1002-User permission denied" // enums of business error codes, in JSON format, with the key being the business error code and the value being the description of that error code. This only takes effect when wrapCodeMsg is set to true.
|
||||
// securityDefinitionsFromJson is a custom authentication configuration, and the JSON content will be directly inserted into the securityDefinitions of Swagger.
|
||||
// Format reference: https://swagger.io/specification/v2/#security-definitions-object
|
||||
// You can declare authType in the @server of the API to specify the authentication type used for its routes.
|
||||
securityDefinitionsFromJson: `{"apiKey":{"description":"apiKey type description","type":"apiKey","name":"x-api-key","in":"header"}}`
|
||||
useDefinitions: true // if set true, the definitions will be generated in the swagger.json for response body or json request body file, and the models will be referenced in the API.
|
||||
)
|
||||
|
||||
type (
|
||||
|
||||
4980
tools/goctl/api/swagger/example/example.swagger.json
Normal file
4980
tools/goctl/api/swagger/example/example.swagger.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -12,15 +12,16 @@ info (
|
||||
licenseURL: "https://github.com/zeromicro/go-zero" // 对应 swagger 的 licenseURL
|
||||
consumes: "application/json" // 对应 swagger 的 consumes,不填默认为 application/json
|
||||
produces: "application/json" // 对应 swagger 的 produces,不填默认为 application/json
|
||||
schemes: "https" // 对应 swagger 的 schemes,不填默认为 https
|
||||
schemes: "http,https" // 对应 swagger 的 schemes,不填默认为 https
|
||||
host: "example.com" // 对应 swagger 的 host,不填默认为 127.0.0.1
|
||||
basePath: "/v1" // 对应 swagger 的 basePath,不填默认为 /
|
||||
wrapCodeMsg: "true" // 是否用 code-msg 通用响应体,如果开启,则以格式 {"code":0,"msg":"OK","data":$data} 包括响应体
|
||||
bizCodeEnumDescription: "1001-未登录<br>1002-无权限操作" // 业务错误码枚举描述,json 格式,key 为业务错误码,value 为该错误码的描述,仅当 wrapCodeMsg 为 true 时生效
|
||||
wrapCodeMsg: true // 是否用 code-msg 通用响应体,如果开启,则以格式 {"code":0,"msg":"OK","data":$data} 包括响应体
|
||||
bizCodeEnumDescription: "1001-未登录<br>1002-无权限操作" // 全局业务错误码枚举描述,json 格式,key 为业务错误码,value 为该错误码的描述,仅当 wrapCodeMsg 为 true 时生效
|
||||
// securityDefinitionsFromJson 为自定义鉴权配置,json 内容将直接放入 swagger 的 securityDefinitions 中,
|
||||
// 格式参考 https://swagger.io/specification/v2/#security-definitions-object
|
||||
// 在 api 的 @server 中可声明 authType 来指定其路由使用的鉴权类型
|
||||
securityDefinitionsFromJson: `{"apiKey":{"description":"apiKey 类型鉴权自定义","type":"apiKey","name":"x-api-key","in":"header"}}`
|
||||
useDefinitions: true// 开启声明将生成models 进行关联,definitions 仅对响应体和 json 请求体生效
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -48,10 +49,12 @@ type (
|
||||
summary: "query 类型接口集合" // 对应 swagger 的 summary
|
||||
prefix: v1
|
||||
authType: apiKey // 指定该路由使用的鉴权类型,值为 securityDefinitionsFromJson 中定义的名称
|
||||
group:"demo"
|
||||
)
|
||||
service Swagger {
|
||||
@doc (
|
||||
description: "query 接口"
|
||||
bizCodeEnumDescription: " 1003-用不存在<br>1004-非法操作" // 接口级别业务错误码枚举描述,会覆盖全局的业务错误码,json 格式,key 为业务错误码,value 为该错误码的描述,仅当 wrapCodeMsg 为 true 且 useDefinitions 为 false 时生效
|
||||
)
|
||||
@handler query
|
||||
get /query (QueryReq) returns (QueryResp)
|
||||
|
||||
5608
tools/goctl/api/swagger/example/example_cn.swagger.json
Normal file
5608
tools/goctl/api/swagger/example/example_cn.swagger.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -81,21 +81,21 @@ func enumsValueFromOptions(options []string) []any {
|
||||
return []any{}
|
||||
}
|
||||
|
||||
func defValueFromOptions(options []string, apiType spec.Type) any {
|
||||
tp := sampleTypeFromGoType(apiType)
|
||||
return valueFromOptions(options, defFlag, tp)
|
||||
func defValueFromOptions(ctx Context, options []string, apiType spec.Type) any {
|
||||
tp := sampleTypeFromGoType(ctx, apiType)
|
||||
return valueFromOptions(ctx, options, defFlag, tp)
|
||||
}
|
||||
|
||||
func exampleValueFromOptions(options []string, apiType spec.Type) any {
|
||||
tp := sampleTypeFromGoType(apiType)
|
||||
val := valueFromOptions(options, exampleFlag, tp)
|
||||
func exampleValueFromOptions(ctx Context, options []string, apiType spec.Type) any {
|
||||
tp := sampleTypeFromGoType(ctx, apiType)
|
||||
val := valueFromOptions(ctx, options, exampleFlag, tp)
|
||||
if val != nil {
|
||||
return val
|
||||
}
|
||||
return defValueFromOptions(options, apiType)
|
||||
return defValueFromOptions(ctx, options, apiType)
|
||||
}
|
||||
|
||||
func valueFromOptions(options []string, key string, tp string) any {
|
||||
func valueFromOptions(_ Context, options []string, key string, tp string) any {
|
||||
if len(options) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -103,16 +103,18 @@ func valueFromOptions(options []string, key string, tp string) any {
|
||||
if strings.HasPrefix(option, key) {
|
||||
s := option[len(key):]
|
||||
switch tp {
|
||||
case "integer":
|
||||
case swaggerTypeInteger:
|
||||
val, _ := strconv.ParseInt(s, 10, 64)
|
||||
return val
|
||||
case "boolean":
|
||||
case swaggerTypeBoolean:
|
||||
val, _ := strconv.ParseBool(s)
|
||||
return val
|
||||
case "number":
|
||||
case swaggerTypeNumber:
|
||||
val, _ := strconv.ParseFloat(s, 64)
|
||||
return val
|
||||
case "string":
|
||||
case swaggerTypeArray:
|
||||
return s
|
||||
case swaggerTypeString:
|
||||
return s
|
||||
default:
|
||||
return nil
|
||||
|
||||
@@ -161,7 +161,7 @@ func TestDefValueFromOptions(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := defValueFromOptions(tt.options, tt.apiType)
|
||||
result := defValueFromOptions(testingContext(t), tt.options, tt.apiType)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
@@ -202,7 +202,7 @@ func TestExampleValueFromOptions(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
exampleValueFromOptions(tt.options, tt.apiType)
|
||||
exampleValueFromOptions(testingContext(t), tt.options, tt.apiType)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -247,7 +247,7 @@ func TestValueFromOptions(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := valueFromOptions(tt.options, tt.key, tt.tp)
|
||||
result := valueFromOptions(testingContext(t), tt.options, tt.key, tt.tp)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,7 +8,25 @@ import (
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
func isPostJson(ctx Context, method string, tp apiSpec.Type) (string, bool) {
|
||||
if strings.EqualFold(method, http.MethodPost) {
|
||||
return "", false
|
||||
}
|
||||
structType, ok := tp.(apiSpec.DefineStruct)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
var isPostJson bool
|
||||
rangeMemberAndDo(ctx, structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
||||
jsonTag, _ := tag.Get(tagJson)
|
||||
if !isPostJson {
|
||||
isPostJson = jsonTag != nil
|
||||
}
|
||||
})
|
||||
return structType.RawName, isPostJson
|
||||
}
|
||||
|
||||
func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Parameter {
|
||||
if tp == nil {
|
||||
return []spec.Parameter{}
|
||||
}
|
||||
@@ -16,12 +34,13 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
if !ok {
|
||||
return []spec.Parameter{}
|
||||
}
|
||||
|
||||
var (
|
||||
resp []spec.Parameter
|
||||
properties = map[string]spec.Schema{}
|
||||
requiredFields []string
|
||||
)
|
||||
rangeMemberAndDo(structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
||||
rangeMemberAndDo(ctx, structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
||||
headerTag, _ := tag.Get(tagHeader)
|
||||
hasHeader := headerTag != nil
|
||||
|
||||
@@ -44,10 +63,9 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
Enum: enumsValueFromOptions(headerTag.Options),
|
||||
},
|
||||
SimpleSchema: spec.SimpleSchema{
|
||||
Type: sampleTypeFromGoType(member.Type),
|
||||
Default: defValueFromOptions(headerTag.Options, member.Type),
|
||||
Example: exampleValueFromOptions(headerTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(member.Type),
|
||||
Type: sampleTypeFromGoType(ctx, member.Type),
|
||||
Default: defValueFromOptions(ctx, headerTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(ctx, member.Type),
|
||||
},
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInHeader,
|
||||
@@ -68,10 +86,9 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
Enum: enumsValueFromOptions(pathParameterTag.Options),
|
||||
},
|
||||
SimpleSchema: spec.SimpleSchema{
|
||||
Type: sampleTypeFromGoType(member.Type),
|
||||
Default: defValueFromOptions(pathParameterTag.Options, member.Type),
|
||||
Example: exampleValueFromOptions(pathParameterTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(member.Type),
|
||||
Type: sampleTypeFromGoType(ctx, member.Type),
|
||||
Default: defValueFromOptions(ctx, pathParameterTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(ctx, member.Type),
|
||||
},
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInPath,
|
||||
@@ -93,10 +110,9 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
Enum: enumsValueFromOptions(formTag.Options),
|
||||
},
|
||||
SimpleSchema: spec.SimpleSchema{
|
||||
Type: sampleTypeFromGoType(member.Type),
|
||||
Default: defValueFromOptions(formTag.Options, member.Type),
|
||||
Example: exampleValueFromOptions(formTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(member.Type),
|
||||
Type: sampleTypeFromGoType(ctx, member.Type),
|
||||
Default: defValueFromOptions(ctx, formTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(ctx, member.Type),
|
||||
},
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInQuery,
|
||||
@@ -116,10 +132,9 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
Enum: enumsValueFromOptions(formTag.Options),
|
||||
},
|
||||
SimpleSchema: spec.SimpleSchema{
|
||||
Type: sampleTypeFromGoType(member.Type),
|
||||
Default: defValueFromOptions(formTag.Options, member.Type),
|
||||
Example: exampleValueFromOptions(formTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(member.Type),
|
||||
Type: sampleTypeFromGoType(ctx, member.Type),
|
||||
Default: defValueFromOptions(ctx, formTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(ctx, member.Type),
|
||||
},
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInForm,
|
||||
@@ -139,25 +154,25 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
}
|
||||
var schema = spec.Schema{
|
||||
SwaggerSchemaProps: spec.SwaggerSchemaProps{
|
||||
Example: exampleValueFromOptions(jsonTag.Options, member.Type),
|
||||
Example: exampleValueFromOptions(ctx, jsonTag.Options, member.Type),
|
||||
},
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Description: formatComment(member.Comment),
|
||||
Type: typeFromGoType(member.Type),
|
||||
Default: defValueFromOptions(jsonTag.Options, member.Type),
|
||||
Type: typeFromGoType(ctx, member.Type),
|
||||
Default: defValueFromOptions(ctx, jsonTag.Options, member.Type),
|
||||
Maximum: maximum,
|
||||
ExclusiveMaximum: exclusiveMaximum,
|
||||
Minimum: minimum,
|
||||
ExclusiveMinimum: exclusiveMinimum,
|
||||
Enum: enumsValueFromOptions(jsonTag.Options),
|
||||
AdditionalProperties: mapFromGoType(member.Type),
|
||||
AdditionalProperties: mapFromGoType(ctx, member.Type),
|
||||
},
|
||||
}
|
||||
switch sampleTypeFromGoType(member.Type) {
|
||||
switch sampleTypeFromGoType(ctx, member.Type) {
|
||||
case swaggerTypeArray:
|
||||
schema.Items = itemsFromGoType(member.Type)
|
||||
schema.Items = itemsFromGoType(ctx, member.Type)
|
||||
case swaggerTypeObject:
|
||||
p, r := propertiesFromType(member.Type)
|
||||
p, r := propertiesFromType(ctx, member.Type)
|
||||
schema.Properties = p
|
||||
schema.Required = r
|
||||
}
|
||||
@@ -165,20 +180,38 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
}
|
||||
})
|
||||
if len(properties) > 0 {
|
||||
resp = append(resp, spec.Parameter{
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInBody,
|
||||
Name: paramsInBody,
|
||||
Required: true,
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: typeFromGoType(structType),
|
||||
Properties: properties,
|
||||
Required: requiredFields,
|
||||
if ctx.UseDefinitions {
|
||||
structName, ok := isPostJson(ctx, method, tp)
|
||||
if ok {
|
||||
resp = append(resp, spec.Parameter{
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInBody,
|
||||
Name: paramsInBody,
|
||||
Required: true,
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Ref: spec.MustCreateRef(getRefName(structName)),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
resp = append(resp, spec.Parameter{
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInBody,
|
||||
Name: paramsInBody,
|
||||
Required: true,
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: typeFromGoType(ctx, structType),
|
||||
Properties: properties,
|
||||
Required: requiredFields,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
@@ -7,20 +7,21 @@ import (
|
||||
|
||||
"github.com/go-openapi/spec"
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/stringx"
|
||||
)
|
||||
|
||||
func spec2Paths(info apiSpec.Info, srv apiSpec.Service) *spec.Paths {
|
||||
func spec2Paths(ctx Context, srv apiSpec.Service) *spec.Paths {
|
||||
paths := &spec.Paths{
|
||||
Paths: make(map[string]spec.PathItem),
|
||||
}
|
||||
for _, group := range srv.Groups {
|
||||
prefix := path.Clean(strings.TrimPrefix(group.GetAnnotation("prefix"), "/"))
|
||||
prefix := path.Clean(strings.TrimPrefix(group.GetAnnotation(propertyKeyPrefix), "/"))
|
||||
for _, route := range group.Routes {
|
||||
routPath := pathVariable2SwaggerVariable(route.Path)
|
||||
routPath := pathVariable2SwaggerVariable(ctx, route.Path)
|
||||
if len(prefix) > 0 && prefix != "." {
|
||||
routPath = "/" + path.Clean(prefix) + routPath
|
||||
}
|
||||
pathItem := spec2Path(info, group, route)
|
||||
pathItem := spec2Path(ctx, group, route)
|
||||
existPathItem, ok := paths.Paths[routPath]
|
||||
if !ok {
|
||||
paths.Paths[routPath] = pathItem
|
||||
@@ -60,8 +61,8 @@ func mergePathItem(old, new spec.PathItem) spec.PathItem {
|
||||
return old
|
||||
}
|
||||
|
||||
func spec2Path(info apiSpec.Info, group apiSpec.Group, route apiSpec.Route) spec.PathItem {
|
||||
authType := getStringFromKVOrDefault(group.Annotation.Properties, "authType", "")
|
||||
func spec2Path(ctx Context, group apiSpec.Group, route apiSpec.Route) spec.PathItem {
|
||||
authType := getStringFromKVOrDefault(group.Annotation.Properties, propertyKeyAuthType, "")
|
||||
var security []map[string][]string
|
||||
if len(authType) > 0 {
|
||||
security = []map[string][]string{
|
||||
@@ -70,22 +71,29 @@ func spec2Path(info apiSpec.Info, group apiSpec.Group, route apiSpec.Route) spec
|
||||
},
|
||||
}
|
||||
}
|
||||
groupName := getStringFromKVOrDefault(group.Annotation.Properties, propertyKeyGroup, "")
|
||||
operationId := route.Handler
|
||||
if len(groupName) > 0 {
|
||||
operationId = stringx.From(groupName + "_" + route.Handler).ToCamel()
|
||||
}
|
||||
operationId = stringx.From(operationId).Untitle()
|
||||
op := &spec.Operation{
|
||||
OperationProps: spec.OperationProps{
|
||||
Description: getStringFromKVOrDefault(route.AtDoc.Properties, "description", ""),
|
||||
Consumes: consumesFromTypeOrDef(route.Method, route.RequestType),
|
||||
Produces: getListFromInfoOrDefault(route.AtDoc.Properties, "produces", []string{applicationJson}),
|
||||
Schemes: getListFromInfoOrDefault(route.AtDoc.Properties, "schemes", []string{schemeHttps}),
|
||||
Tags: getListFromInfoOrDefault(group.Annotation.Properties, "tags", []string{""}),
|
||||
Summary: getStringFromKVOrDefault(route.AtDoc.Properties, "summary", ""),
|
||||
Deprecated: getBoolFromKVOrDefault(route.AtDoc.Properties, "deprecated", false),
|
||||
Parameters: parametersFromType(route.Method, route.RequestType),
|
||||
Responses: jsonResponseFromType(info, route.ResponseType),
|
||||
Description: getStringFromKVOrDefault(route.AtDoc.Properties, propertyKeyDescription, ""),
|
||||
Consumes: consumesFromTypeOrDef(ctx, route.Method, route.RequestType),
|
||||
Produces: getListFromInfoOrDefault(route.AtDoc.Properties, propertyKeyProduces, []string{applicationJson}),
|
||||
Schemes: getListFromInfoOrDefault(route.AtDoc.Properties, propertyKeySchemes, []string{schemeHttps}),
|
||||
Tags: getListFromInfoOrDefault(group.Annotation.Properties, propertyKeyTags, getListFromInfoOrDefault(group.Annotation.Properties, propertyKeySummary, []string{})),
|
||||
Summary: getStringFromKVOrDefault(route.AtDoc.Properties, propertyKeySummary, getFirstUsableString(route.AtDoc.Text, route.Handler)),
|
||||
ID: operationId,
|
||||
Deprecated: getBoolFromKVOrDefault(route.AtDoc.Properties, propertyKeyDeprecated, false),
|
||||
Parameters: parametersFromType(ctx, route.Method, route.RequestType),
|
||||
Security: security,
|
||||
Responses: jsonResponseFromType(ctx, route.AtDoc, route.ResponseType),
|
||||
},
|
||||
}
|
||||
externalDocsDescription := getStringFromKVOrDefault(route.AtDoc.Properties, "externalDocsDescription", "")
|
||||
externalDocsURL := getStringFromKVOrDefault(route.AtDoc.Properties, "externalDocsURL", "")
|
||||
externalDocsDescription := getStringFromKVOrDefault(route.AtDoc.Properties, propertyKeyExternalDocsDescription, "")
|
||||
externalDocsURL := getStringFromKVOrDefault(route.AtDoc.Properties, propertyKeyExternalDocsURL, "")
|
||||
if len(externalDocsDescription) > 0 || len(externalDocsURL) > 0 {
|
||||
op.ExternalDocs = &spec.ExternalDocumentation{
|
||||
Description: externalDocsDescription,
|
||||
|
||||
@@ -5,18 +5,18 @@ import (
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func propertiesFromType(tp apiSpec.Type) (spec.SchemaProperties, []string) {
|
||||
func propertiesFromType(ctx Context, tp apiSpec.Type) (spec.SchemaProperties, []string) {
|
||||
var (
|
||||
properties = map[string]spec.Schema{}
|
||||
requiredFields []string
|
||||
)
|
||||
switch val := tp.(type) {
|
||||
case apiSpec.PointerType:
|
||||
return propertiesFromType(val.Type)
|
||||
return propertiesFromType(ctx, val.Type)
|
||||
case apiSpec.ArrayType:
|
||||
return propertiesFromType(val.Value)
|
||||
return propertiesFromType(ctx, val.Value)
|
||||
case apiSpec.DefineStruct, apiSpec.NestedStruct:
|
||||
rangeMemberAndDo(val, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
||||
rangeMemberAndDo(ctx, val, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
||||
var (
|
||||
jsonTagString = member.Name
|
||||
minimum, maximum *float64
|
||||
@@ -24,42 +24,63 @@ func propertiesFromType(tp apiSpec.Type) (spec.SchemaProperties, []string) {
|
||||
example, defaultValue any
|
||||
enum []any
|
||||
)
|
||||
pathTag, _ := tag.Get(tagPath)
|
||||
if pathTag != nil {
|
||||
return
|
||||
}
|
||||
formTag, _ := tag.Get(tagForm)
|
||||
if formTag != nil {
|
||||
return
|
||||
}
|
||||
headerTag, _ := tag.Get(tagHeader)
|
||||
if headerTag != nil {
|
||||
return
|
||||
}
|
||||
|
||||
jsonTag, _ := tag.Get(tagJson)
|
||||
if jsonTag != nil {
|
||||
jsonTagString = jsonTag.Name
|
||||
minimum, maximum, exclusiveMinimum, exclusiveMaximum = rangeValueFromOptions(jsonTag.Options)
|
||||
example = exampleValueFromOptions(jsonTag.Options, member.Type)
|
||||
defaultValue = defValueFromOptions(jsonTag.Options, member.Type)
|
||||
example = exampleValueFromOptions(ctx, jsonTag.Options, member.Type)
|
||||
defaultValue = defValueFromOptions(ctx, jsonTag.Options, member.Type)
|
||||
enum = enumsValueFromOptions(jsonTag.Options)
|
||||
}
|
||||
|
||||
if required {
|
||||
requiredFields = append(requiredFields, jsonTagString)
|
||||
}
|
||||
var schema = spec.Schema{
|
||||
|
||||
schema := spec.Schema{
|
||||
SwaggerSchemaProps: spec.SwaggerSchemaProps{
|
||||
Example: example,
|
||||
},
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Description: formatComment(member.Comment),
|
||||
Type: typeFromGoType(member.Type),
|
||||
Type: typeFromGoType(ctx, member.Type),
|
||||
Default: defaultValue,
|
||||
Maximum: maximum,
|
||||
ExclusiveMaximum: exclusiveMaximum,
|
||||
Minimum: minimum,
|
||||
ExclusiveMinimum: exclusiveMinimum,
|
||||
Enum: enum,
|
||||
AdditionalProperties: mapFromGoType(member.Type),
|
||||
AdditionalProperties: mapFromGoType(ctx, member.Type),
|
||||
},
|
||||
}
|
||||
switch sampleTypeFromGoType(member.Type) {
|
||||
|
||||
switch sampleTypeFromGoType(ctx, member.Type) {
|
||||
case swaggerTypeArray:
|
||||
schema.Items = itemsFromGoType(member.Type)
|
||||
schema.Items = itemsFromGoType(ctx, member.Type)
|
||||
case swaggerTypeObject:
|
||||
p, r := propertiesFromType(member.Type)
|
||||
p, r := propertiesFromType(ctx, member.Type)
|
||||
schema.Properties = p
|
||||
schema.Required = r
|
||||
}
|
||||
if ctx.UseDefinitions {
|
||||
structName, containsStruct := containsStruct(member.Type)
|
||||
if containsStruct {
|
||||
schema.SchemaProps.Ref = spec.MustCreateRef(getRefName(structName))
|
||||
}
|
||||
}
|
||||
|
||||
properties[jsonTagString] = schema
|
||||
})
|
||||
@@ -67,3 +88,22 @@ func propertiesFromType(tp apiSpec.Type) (spec.SchemaProperties, []string) {
|
||||
|
||||
return properties, requiredFields
|
||||
}
|
||||
|
||||
func containsStruct(tp apiSpec.Type) (string, bool) {
|
||||
switch val := tp.(type) {
|
||||
case apiSpec.PointerType:
|
||||
return containsStruct(val.Type)
|
||||
case apiSpec.ArrayType:
|
||||
return containsStruct(val.Value)
|
||||
case apiSpec.DefineStruct:
|
||||
return val.RawName, true
|
||||
case apiSpec.MapType:
|
||||
return containsStruct(val.Value)
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func getRefName(typeName string) string {
|
||||
return "#/definitions/" + typeName
|
||||
}
|
||||
|
||||
@@ -1,25 +1,62 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-openapi/spec"
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func jsonResponseFromType(info apiSpec.Info, tp apiSpec.Type) *spec.Responses {
|
||||
p, _ := propertiesFromType(tp)
|
||||
func jsonResponseFromType(ctx Context, atDoc apiSpec.AtDoc, tp apiSpec.Type) *spec.Responses {
|
||||
if tp == nil {
|
||||
return &spec.Responses{
|
||||
ResponsesProps: spec.ResponsesProps{
|
||||
StatusCodeResponses: map[int]spec.Response{
|
||||
http.StatusOK: {
|
||||
ResponseProps: spec.ResponseProps{
|
||||
Description: "",
|
||||
Schema: &spec.Schema{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
props := spec.SchemaProps{
|
||||
Type: typeFromGoType(tp),
|
||||
Properties: p,
|
||||
AdditionalProperties: mapFromGoType(tp),
|
||||
Items: itemsFromGoType(tp),
|
||||
AdditionalProperties: mapFromGoType(ctx, tp),
|
||||
Items: itemsFromGoType(ctx, tp),
|
||||
}
|
||||
if ctx.UseDefinitions {
|
||||
structName, ok := containsStruct(tp)
|
||||
if ok {
|
||||
props.Ref = spec.MustCreateRef(getRefName(structName))
|
||||
return &spec.Responses{
|
||||
ResponsesProps: spec.ResponsesProps{
|
||||
StatusCodeResponses: map[int]spec.Response{
|
||||
http.StatusOK: {
|
||||
ResponseProps: spec.ResponseProps{
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: wrapCodeMsgProps(ctx, props, atDoc),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p, _ := propertiesFromType(ctx, tp)
|
||||
props.Type = typeFromGoType(ctx, tp)
|
||||
props.Properties = p
|
||||
return &spec.Responses{
|
||||
ResponsesProps: spec.ResponsesProps{
|
||||
Default: &spec.Response{
|
||||
ResponseProps: spec.ResponseProps{
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: wrapCodeMsgProps(props, info),
|
||||
StatusCodeResponses: map[int]spec.Response{
|
||||
http.StatusOK: {
|
||||
ResponseProps: spec.ResponseProps{
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: wrapCodeMsgProps(ctx, props, atDoc),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -8,12 +8,11 @@ import (
|
||||
"github.com/go-openapi/spec"
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||
)
|
||||
|
||||
func spec2Swagger(api *apiSpec.ApiSpec) (*spec.Swagger, error) {
|
||||
ctx := contextFromApi(api.Info)
|
||||
extensions, info := specExtensions(api.Info)
|
||||
|
||||
var securityDefinitions spec.SecurityDefinitions
|
||||
securityDefinitionsFromJson := getStringFromKVOrDefault(api.Info.Properties, "securityDefinitionsFromJson", `{}`)
|
||||
_ = json.Unmarshal([]byte(securityDefinitionsFromJson), &securityDefinitions)
|
||||
@@ -22,14 +21,15 @@ func spec2Swagger(api *apiSpec.ApiSpec) (*spec.Swagger, error) {
|
||||
Extensions: extensions,
|
||||
},
|
||||
SwaggerProps: spec.SwaggerProps{
|
||||
Consumes: getListFromInfoOrDefault(api.Info.Properties, "consumes", []string{applicationJson}),
|
||||
Produces: getListFromInfoOrDefault(api.Info.Properties, "produces", []string{applicationJson}),
|
||||
Schemes: getListFromInfoOrDefault(api.Info.Properties, "schemes", []string{schemeHttps}),
|
||||
Definitions: definitionsFromTypes(ctx, api.Types),
|
||||
Consumes: getListFromInfoOrDefault(api.Info.Properties, propertyKeyConsumes, []string{applicationJson}),
|
||||
Produces: getListFromInfoOrDefault(api.Info.Properties, propertyKeyProduces, []string{applicationJson}),
|
||||
Schemes: getListFromInfoOrDefault(api.Info.Properties, propertyKeySchemes, []string{schemeHttps}),
|
||||
Swagger: swaggerVersion,
|
||||
Info: info,
|
||||
Host: getStringFromKVOrDefault(api.Info.Properties, "host", defaultHost),
|
||||
BasePath: getStringFromKVOrDefault(api.Info.Properties, "basePath", defaultBasePath),
|
||||
Paths: spec2Paths(api.Info, api.Service),
|
||||
Host: getStringFromKVOrDefault(api.Info.Properties, propertyKeyHost, ""),
|
||||
BasePath: getStringFromKVOrDefault(api.Info.Properties, propertyKeyBasePath, defaultBasePath),
|
||||
Paths: spec2Paths(ctx, api.Service),
|
||||
SecurityDefinitions: securityDefinitions,
|
||||
},
|
||||
}
|
||||
@@ -42,7 +42,7 @@ func formatComment(comment string) string {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func sampleItemsFromGoType(tp apiSpec.Type) *spec.Items {
|
||||
func sampleItemsFromGoType(ctx Context, tp apiSpec.Type) *spec.Items {
|
||||
val, ok := tp.(apiSpec.ArrayType)
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -52,14 +52,14 @@ func sampleItemsFromGoType(tp apiSpec.Type) *spec.Items {
|
||||
case apiSpec.PrimitiveType:
|
||||
return &spec.Items{
|
||||
SimpleSchema: spec.SimpleSchema{
|
||||
Type: sampleTypeFromGoType(item),
|
||||
Type: sampleTypeFromGoType(ctx, item),
|
||||
},
|
||||
}
|
||||
case apiSpec.ArrayType:
|
||||
return &spec.Items{
|
||||
SimpleSchema: spec.SimpleSchema{
|
||||
Type: sampleTypeFromGoType(item),
|
||||
Items: sampleItemsFromGoType(item),
|
||||
Type: sampleTypeFromGoType(ctx, item),
|
||||
Items: sampleItemsFromGoType(ctx, item),
|
||||
},
|
||||
}
|
||||
default: // unsupported type
|
||||
@@ -68,30 +68,30 @@ func sampleItemsFromGoType(tp apiSpec.Type) *spec.Items {
|
||||
}
|
||||
|
||||
// itemsFromGoType returns the schema or array of the type, just for non json body parameters.
|
||||
func itemsFromGoType(tp apiSpec.Type) *spec.SchemaOrArray {
|
||||
func itemsFromGoType(ctx Context, tp apiSpec.Type) *spec.SchemaOrArray {
|
||||
array, ok := tp.(apiSpec.ArrayType)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return itemFromGoType(array.Value)
|
||||
return itemFromGoType(ctx, array.Value)
|
||||
}
|
||||
|
||||
func mapFromGoType(tp apiSpec.Type) *spec.SchemaOrBool {
|
||||
func mapFromGoType(ctx Context, tp apiSpec.Type) *spec.SchemaOrBool {
|
||||
mapType, ok := tp.(apiSpec.MapType)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var schema = &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: typeFromGoType(mapType.Value),
|
||||
AdditionalProperties: mapFromGoType(mapType.Value),
|
||||
Type: typeFromGoType(ctx, mapType.Value),
|
||||
AdditionalProperties: mapFromGoType(ctx, mapType.Value),
|
||||
},
|
||||
}
|
||||
switch sampleTypeFromGoType(mapType.Value) {
|
||||
switch sampleTypeFromGoType(ctx, mapType.Value) {
|
||||
case swaggerTypeArray:
|
||||
schema.Items = itemsFromGoType(mapType.Value)
|
||||
schema.Items = itemsFromGoType(ctx, mapType.Value)
|
||||
case swaggerTypeObject:
|
||||
p, r := propertiesFromType(mapType.Value)
|
||||
p, r := propertiesFromType(ctx, mapType.Value)
|
||||
schema.Properties = p
|
||||
schema.Required = r
|
||||
}
|
||||
@@ -102,37 +102,37 @@ func mapFromGoType(tp apiSpec.Type) *spec.SchemaOrBool {
|
||||
}
|
||||
|
||||
// itemFromGoType returns the schema or array of the type, just for non json body parameters.
|
||||
func itemFromGoType(tp apiSpec.Type) *spec.SchemaOrArray {
|
||||
func itemFromGoType(ctx Context, tp apiSpec.Type) *spec.SchemaOrArray {
|
||||
switch itemType := tp.(type) {
|
||||
case apiSpec.PrimitiveType:
|
||||
return &spec.SchemaOrArray{
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: typeFromGoType(tp),
|
||||
Type: typeFromGoType(ctx, tp),
|
||||
},
|
||||
},
|
||||
}
|
||||
case apiSpec.DefineStruct, apiSpec.NestedStruct:
|
||||
properties, requiredFields := propertiesFromType(itemType)
|
||||
case apiSpec.DefineStruct, apiSpec.NestedStruct, apiSpec.MapType:
|
||||
properties, requiredFields := propertiesFromType(ctx, itemType)
|
||||
return &spec.SchemaOrArray{
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: typeFromGoType(itemType),
|
||||
Items: itemsFromGoType(itemType),
|
||||
Type: typeFromGoType(ctx, itemType),
|
||||
Items: itemsFromGoType(ctx, itemType),
|
||||
Properties: properties,
|
||||
Required: requiredFields,
|
||||
AdditionalProperties: mapFromGoType(itemType),
|
||||
AdditionalProperties: mapFromGoType(ctx, itemType),
|
||||
},
|
||||
},
|
||||
}
|
||||
case apiSpec.PointerType:
|
||||
return itemFromGoType(itemType.Type)
|
||||
return itemFromGoType(ctx, itemType.Type)
|
||||
case apiSpec.ArrayType:
|
||||
return &spec.SchemaOrArray{
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: typeFromGoType(itemType),
|
||||
Items: itemsFromGoType(itemType),
|
||||
Type: typeFromGoType(ctx, itemType),
|
||||
Items: itemsFromGoType(ctx, itemType),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -140,7 +140,7 @@ func itemFromGoType(tp apiSpec.Type) *spec.SchemaOrArray {
|
||||
return nil
|
||||
}
|
||||
|
||||
func typeFromGoType(tp apiSpec.Type) []string {
|
||||
func typeFromGoType(ctx Context, tp apiSpec.Type) []string {
|
||||
switch val := tp.(type) {
|
||||
case apiSpec.PrimitiveType:
|
||||
res, ok := tpMapper[val.RawName]
|
||||
@@ -152,12 +152,12 @@ func typeFromGoType(tp apiSpec.Type) []string {
|
||||
case apiSpec.DefineStruct, apiSpec.MapType:
|
||||
return []string{swaggerTypeObject}
|
||||
case apiSpec.PointerType:
|
||||
return typeFromGoType(val.Type)
|
||||
return typeFromGoType(ctx, val.Type)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sampleTypeFromGoType(tp apiSpec.Type) string {
|
||||
func sampleTypeFromGoType(ctx Context, tp apiSpec.Type) string {
|
||||
switch val := tp.(type) {
|
||||
case apiSpec.PrimitiveType:
|
||||
return tpMapper[val.RawName]
|
||||
@@ -166,31 +166,30 @@ func sampleTypeFromGoType(tp apiSpec.Type) string {
|
||||
case apiSpec.DefineStruct, apiSpec.MapType, apiSpec.NestedStruct:
|
||||
return swaggerTypeObject
|
||||
case apiSpec.PointerType:
|
||||
return sampleTypeFromGoType(val.Type)
|
||||
return sampleTypeFromGoType(ctx, val.Type)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func typeContainsTag(structType apiSpec.DefineStruct, tag string) bool {
|
||||
for _, field := range structType.Members {
|
||||
tags, _ := apiSpec.Parse(field.Tag)
|
||||
for _, t := range tags.Tags() {
|
||||
if t.Key == tag {
|
||||
return true
|
||||
}
|
||||
func typeContainsTag(ctx Context, structType apiSpec.DefineStruct, tag string) bool {
|
||||
members := expandMembers(ctx, structType)
|
||||
for _, member := range members {
|
||||
tags, _ := apiSpec.Parse(member.Tag)
|
||||
if _, err := tags.Get(tag); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func expandMembers(tp apiSpec.Type) []apiSpec.Member {
|
||||
func expandMembers(ctx Context, tp apiSpec.Type) []apiSpec.Member {
|
||||
var members []apiSpec.Member
|
||||
switch val := tp.(type) {
|
||||
case apiSpec.DefineStruct:
|
||||
for _, v := range val.Members {
|
||||
if v.IsInline {
|
||||
members = append(members, expandMembers(v.Type)...)
|
||||
members = append(members, expandMembers(ctx, v.Type)...)
|
||||
continue
|
||||
}
|
||||
members = append(members, v)
|
||||
@@ -198,7 +197,7 @@ func expandMembers(tp apiSpec.Type) []apiSpec.Member {
|
||||
case apiSpec.NestedStruct:
|
||||
for _, v := range val.Members {
|
||||
if v.IsInline {
|
||||
members = append(members, expandMembers(v.Type)...)
|
||||
members = append(members, expandMembers(ctx, v.Type)...)
|
||||
continue
|
||||
}
|
||||
members = append(members, v)
|
||||
@@ -208,42 +207,42 @@ func expandMembers(tp apiSpec.Type) []apiSpec.Member {
|
||||
return members
|
||||
}
|
||||
|
||||
func rangeMemberAndDo(structType apiSpec.Type, do func(tag *apiSpec.Tags, required bool, member apiSpec.Member)) {
|
||||
var members = expandMembers(structType)
|
||||
func rangeMemberAndDo(ctx Context, structType apiSpec.Type, do func(tag *apiSpec.Tags, required bool, member apiSpec.Member)) {
|
||||
var members = expandMembers(ctx, structType)
|
||||
|
||||
for _, field := range members {
|
||||
tags, _ := apiSpec.Parse(field.Tag)
|
||||
required := isRequired(tags)
|
||||
required := isRequired(ctx, tags)
|
||||
do(tags, required, field)
|
||||
}
|
||||
}
|
||||
|
||||
func isRequired(tags *apiSpec.Tags) bool {
|
||||
func isRequired(ctx Context, tags *apiSpec.Tags) bool {
|
||||
tag, err := tags.Get(tagJson)
|
||||
if err == nil {
|
||||
return !isOptional(tag.Options)
|
||||
return !isOptional(ctx, tag.Options)
|
||||
}
|
||||
tag, err = tags.Get(tagForm)
|
||||
if err == nil {
|
||||
return !isOptional(tag.Options)
|
||||
return !isOptional(ctx, tag.Options)
|
||||
}
|
||||
tag, err = tags.Get(tagPath)
|
||||
if err == nil {
|
||||
return !isOptional(tag.Options)
|
||||
return !isOptional(ctx, tag.Options)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isOptional(options []string) bool {
|
||||
func isOptional(_ Context, options []string) bool {
|
||||
for _, option := range options {
|
||||
if option == "optional" {
|
||||
if option == optionalFlag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func pathVariable2SwaggerVariable(path string) string {
|
||||
func pathVariable2SwaggerVariable(_ Context, path string) string {
|
||||
pathItems := strings.FieldsFunc(path, slashRune)
|
||||
var resp []string
|
||||
for _, v := range pathItems {
|
||||
@@ -256,11 +255,12 @@ func pathVariable2SwaggerVariable(path string) string {
|
||||
return "/" + strings.Join(resp, "/")
|
||||
}
|
||||
|
||||
func wrapCodeMsgProps(properties spec.SchemaProps, api apiSpec.Info) spec.SchemaProps {
|
||||
wrapCodeMsg := getBoolFromKVOrDefault(api.Properties, "wrapCodeMsg", false)
|
||||
if !wrapCodeMsg {
|
||||
func wrapCodeMsgProps(ctx Context, properties spec.SchemaProps, atDoc apiSpec.AtDoc) spec.SchemaProps {
|
||||
if !ctx.WrapCodeMsg {
|
||||
return properties
|
||||
}
|
||||
globalCodeDesc := ctx.BizCodeEnumDescription
|
||||
methodCodeDesc := getStringFromKVOrDefault(atDoc.Properties, propertyKeyBizCodeEnumDescription, globalCodeDesc)
|
||||
return spec.SchemaProps{
|
||||
Type: []string{swaggerTypeObject},
|
||||
Properties: spec.SchemaProperties{
|
||||
@@ -270,7 +270,7 @@ func wrapCodeMsgProps(properties spec.SchemaProps, api apiSpec.Info) spec.Schema
|
||||
},
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: []string{swaggerTypeInteger},
|
||||
Description: getStringFromKVOrDefault(api.Properties, "bizCodeEnumDescription", "business code"),
|
||||
Description: methodCodeDesc,
|
||||
},
|
||||
},
|
||||
"msg": {
|
||||
@@ -293,27 +293,27 @@ func specExtensions(api apiSpec.Info) (spec.Extensions, *spec.Info) {
|
||||
ext := spec.Extensions{}
|
||||
ext.Add("x-goctl-version", version.BuildVersion)
|
||||
ext.Add("x-description", "This is a goctl generated swagger file.")
|
||||
ext.Add("x-date", time.Now().Format("2006-01-02 15:04:05"))
|
||||
ext.Add("x-date", time.Now().Format(time.DateTime))
|
||||
ext.Add("x-github", "https://github.com/zeromicro/go-zero")
|
||||
ext.Add("x-go-zero-doc", "https://go-zero.dev/")
|
||||
|
||||
info := &spec.Info{}
|
||||
info.Description = util.Unquote(api.Properties["description"])
|
||||
info.Title = util.Unquote(api.Properties["title"])
|
||||
info.TermsOfService = util.Unquote(api.Properties["termsOfService"])
|
||||
info.Version = util.Unquote(api.Properties["version"])
|
||||
info.Title = getStringFromKVOrDefault(api.Properties, propertyKeyTitle, "")
|
||||
info.Description = getStringFromKVOrDefault(api.Properties, propertyKeyDescription, "")
|
||||
info.TermsOfService = getStringFromKVOrDefault(api.Properties, propertyKeyTermsOfService, "")
|
||||
info.Version = getStringFromKVOrDefault(api.Properties, propertyKeyVersion, "1.0")
|
||||
|
||||
contactInfo := spec.ContactInfo{}
|
||||
contactInfo.Name = util.Unquote(api.Properties["contactName"])
|
||||
contactInfo.URL = util.Unquote(api.Properties["contactURL"])
|
||||
contactInfo.Email = util.Unquote(api.Properties["contactEmail"])
|
||||
contactInfo.Name = getStringFromKVOrDefault(api.Properties, propertyKeyContactName, "")
|
||||
contactInfo.URL = getStringFromKVOrDefault(api.Properties, propertyKeyContactURL, "")
|
||||
contactInfo.Email = getStringFromKVOrDefault(api.Properties, propertyKeyContactEmail, "")
|
||||
if len(contactInfo.Name) > 0 || len(contactInfo.URL) > 0 || len(contactInfo.Email) > 0 {
|
||||
info.Contact = &contactInfo
|
||||
}
|
||||
|
||||
license := &spec.License{}
|
||||
license.Name = util.Unquote(api.Properties["licenseName"])
|
||||
license.URL = util.Unquote(api.Properties["licenseURL"])
|
||||
license.Name = getStringFromKVOrDefault(api.Properties, propertyKeyLicenseName, "")
|
||||
license.URL = getStringFromKVOrDefault(api.Properties, propertyKeyLicenseURL, "")
|
||||
if len(license.Name) > 0 || len(license.URL) > 0 {
|
||||
info.License = license
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ func Test_pathVariable2SwaggerVariable(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
result := pathVariable2SwaggerVariable(tc.input)
|
||||
result := pathVariable2SwaggerVariable(testingContext(t), tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +34,10 @@ func writeProperty(writer io.Writer, member spec.Member, indent int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.Contains(name, "-") {
|
||||
name = fmt.Sprintf("'%s'", name)
|
||||
}
|
||||
|
||||
comment := member.GetComment()
|
||||
if len(comment) > 0 {
|
||||
comment = strings.TrimPrefix(comment, "//")
|
||||
|
||||
2
tools/goctl/build.env
Normal file
2
tools/goctl/build.env
Normal file
@@ -0,0 +1,2 @@
|
||||
APP_NAME=goctl
|
||||
APP_VERSION=1.8.4-beta
|
||||
50
tools/goctl/build.sh
Normal file
50
tools/goctl/build.sh
Normal file
@@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
|
||||
source build.env
|
||||
APP_NAME=$APP_NAME
|
||||
VERSION=$APP_VERSION
|
||||
BUILD_DIR="dist"
|
||||
ZIP_DIR="${BUILD_DIR}/zips"
|
||||
|
||||
PLATFORMS=(
|
||||
"linux/amd64"
|
||||
"linux/arm64"
|
||||
"darwin/amd64"
|
||||
"darwin/arm64"
|
||||
"windows/amd64"
|
||||
"windows/arm64"
|
||||
)
|
||||
|
||||
rm -rf "${BUILD_DIR}"
|
||||
mkdir -p "${ZIP_DIR}"
|
||||
|
||||
for PLATFORM in "${PLATFORMS[@]}"; do
|
||||
GOOS=${PLATFORM%/*}
|
||||
GOARCH=${PLATFORM#*/}
|
||||
|
||||
OUTPUT="${BUILD_DIR}/${APP_NAME}-${VERSION}-${GOOS}-${GOARCH}"
|
||||
|
||||
if [ "${GOOS}" = "windows" ]; then
|
||||
OUTPUT="${OUTPUT}.exe"
|
||||
fi
|
||||
|
||||
echo "Building for ${GOOS}/${GOARCH}..."
|
||||
|
||||
env GOOS="${GOOS}" GOARCH="${GOARCH}" go build -o "${OUTPUT}" goctl.go
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error building for ${GOOS}/${GOARCH}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ZIP_OUTPUT="${ZIP_DIR}/$(basename "${OUTPUT}")"
|
||||
if [ "${GOOS}" = "windows" ]; then
|
||||
zip -j "${ZIP_OUTPUT%.exe}.zip" "${OUTPUT}"
|
||||
else
|
||||
zip -j "${ZIP_OUTPUT}.zip" "${OUTPUT}"
|
||||
fi
|
||||
|
||||
echo "Created zip: ${ZIP_OUTPUT}.zip"
|
||||
done
|
||||
|
||||
echo "All builds completed successfully. Zip files are in ${ZIP_DIR}/"
|
||||
103
tools/goctl/change.md
Normal file
103
tools/goctl/change.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# 1.8.4-beta
|
||||
|
||||
## swagger
|
||||
- [features] Supported operation id for swagger
|
||||
## Other
|
||||
- Updated version to 1.8.4-beta
|
||||
|
||||
|
||||
# 1.8.4-alpha
|
||||
|
||||
## swagger
|
||||
1. [bug fix] remove example generation when request body are `query`, `path` and `header`
|
||||
- it not supported in api spec 2.0
|
||||
- it's will generate example when request body is json format.
|
||||
2. [features] swagger generation supported definitions
|
||||
- supported response definitions
|
||||
- supported json request body definitions
|
||||
- do not support query and form definitions, use parameters instead.
|
||||
|
||||
**How to use?**
|
||||
Use the `useDefinitions` keyword in the info code block of the API file to declare the enable. This value is a boolean value. When set to `true`, it will enable the generation of definitions. Otherwise, it will be generated according to properties, and the default is `false`, for example:
|
||||
|
||||
```go
|
||||
syntax = "v1"
|
||||
|
||||
info (
|
||||
...
|
||||
wrapCodeMsg: true
|
||||
useDefinitions: true
|
||||
)
|
||||
...
|
||||
```
|
||||
|
||||
the demo result of swagger.json
|
||||
|
||||
```json
|
||||
{
|
||||
...
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"description": "1001-User not login\u003cbr\u003e1002-User permission denied",
|
||||
"type": "integer",
|
||||
"example": 0
|
||||
},
|
||||
"data": {
|
||||
"$ref": "#/definitions/FormResp"
|
||||
},
|
||||
"msg": {
|
||||
"description": "business message",
|
||||
"type": "string",
|
||||
"example": "ok"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
For a complete API example, please refer to the `api/swagger/example/example.api` file in pr. For a complete swagger result example, please refer to the `api/swagger/example/example.swagger.json` file in pr.
|
||||
|
||||
## 2. `goctl api go` code generation
|
||||
- [bug-fix] Add flag `--type-group` to control the output of types(deprecated: experimental switch control type grouping is no longer used), if true, the types in only one group will separate by file.
|
||||
- example `goctl api go --api demo.api --dir demo --type-group`
|
||||
- use `group` keyword in @server block to define the group name in api file, for example
|
||||
```go
|
||||
@server(
|
||||
group: user
|
||||
)
|
||||
service demo{
|
||||
...
|
||||
}
|
||||
```
|
||||
the example of separated types by file
|
||||
```
|
||||
.
|
||||
└── types
|
||||
├── common.go
|
||||
├── gotoolexport.go
|
||||
├── importfile.go
|
||||
├── process.go
|
||||
└── types.go
|
||||
```
|
||||
|
||||
## 3 API Parser
|
||||
- supported identifier value for info key-value in api parser
|
||||
for example
|
||||
|
||||
```
|
||||
syntax = "v1"
|
||||
|
||||
info(
|
||||
enable: true
|
||||
disable: false
|
||||
)
|
||||
...
|
||||
```
|
||||
@@ -4,7 +4,7 @@ go 1.21
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/emicklei/proto v1.14.0
|
||||
github.com/emicklei/proto v1.14.1
|
||||
github.com/fatih/structtag v1.2.0
|
||||
github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e
|
||||
github.com/go-sql-driver/mysql v1.9.0
|
||||
@@ -16,7 +16,7 @@ require (
|
||||
github.com/withfig/autocomplete-tools/integrations/cobra v1.2.1
|
||||
github.com/zeromicro/antlr v0.0.1
|
||||
github.com/zeromicro/ddl-parser v1.0.5
|
||||
github.com/zeromicro/go-zero v1.8.2
|
||||
github.com/zeromicro/go-zero v1.8.3
|
||||
golang.org/x/text v0.22.0
|
||||
google.golang.org/grpc v1.65.0
|
||||
google.golang.org/protobuf v1.36.5
|
||||
@@ -72,7 +72,7 @@ require (
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/redis/go-redis/v9 v9.7.3 // indirect
|
||||
github.com/redis/go-redis/v9 v9.8.0 // indirect
|
||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
|
||||
@@ -32,8 +32,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g=
|
||||
github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
|
||||
github.com/emicklei/proto v1.14.0 h1:WYxC0OrBuuC+FUCTZvb8+fzEHdZMwLEF+OnVfZA3LXU=
|
||||
github.com/emicklei/proto v1.14.0/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
|
||||
github.com/emicklei/proto v1.14.1 h1:fFq+Bj70XXZWXWikcVRvYZxrMS4KIIiPAqdJ8vPrenY=
|
||||
github.com/emicklei/proto v1.14.1/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
|
||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||
github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4=
|
||||
@@ -146,8 +146,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||
github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI=
|
||||
github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
@@ -183,8 +183,8 @@ github.com/zeromicro/antlr v0.0.1 h1:CQpIn/dc0pUjgGQ81y98s/NGOm2Hfru2NNio2I9mQgk
|
||||
github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
|
||||
github.com/zeromicro/ddl-parser v1.0.5 h1:LaVqHdzMTjasua1yYpIYaksxKqRzFrEukj2Wi2EbWaQ=
|
||||
github.com/zeromicro/ddl-parser v1.0.5/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8=
|
||||
github.com/zeromicro/go-zero v1.8.2 h1:AbJckBoojbr1lqCN1dkvURTIHOau7yvKReEd7ZmjuCk=
|
||||
github.com/zeromicro/go-zero v1.8.2/go.mod h1:G5dF+jzCEuq0t1j8qdrtVAy30QMgctGcKSfqFIGsvSg=
|
||||
github.com/zeromicro/go-zero v1.8.3 h1:AwpBJQLAsZAt4OOnK0eR8UU1Ja2RFBIXfKkHdnXQKfc=
|
||||
github.com/zeromicro/go-zero v1.8.3/go.mod h1:EnuEA3XdIQvAvc4WWTskRTO0jM2/aQi7OXv1gKWRNJ0=
|
||||
go.etcd.io/etcd/api/v3 v3.5.15 h1:3KpLJir1ZEBrYuV2v+Twaa/e2MdDCEZ/70H+lzEiwsk=
|
||||
go.etcd.io/etcd/api/v3 v3.5.15/go.mod h1:N9EhGzXq58WuMllgH9ZvnEr7SI9pS0k0+DHZezGp7jM=
|
||||
go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5fWlA=
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
"dir": "{{.goctl.api.dir}}",
|
||||
"iu": "Ignore update",
|
||||
"stdin": "Use stdin to input api doc content, press \"ctrl + d\" to send EOF",
|
||||
"declare": "Use to skip check api types already declare"
|
||||
"declare": "Use to skip check api types already declare, deprecated if goctl version >= 1.8.3"
|
||||
},
|
||||
"go": {
|
||||
"short": "Generate go files for provided api in api file",
|
||||
@@ -38,7 +38,8 @@
|
||||
"remote": "{{.global.remote}}",
|
||||
"branch": "{{.global.branch}}",
|
||||
"style": "{{.global.style}}",
|
||||
"test": "Generate test files"
|
||||
"test": "Generate test files",
|
||||
"type-group": "Generate type group files"
|
||||
},
|
||||
"new": {
|
||||
"short": "Fast create api service",
|
||||
@@ -185,6 +186,7 @@
|
||||
},
|
||||
"pg": {
|
||||
"short": "Generate postgresql model",
|
||||
"prefix": "The cache prefix, effective when --cache is true",
|
||||
"datasource": {
|
||||
"short": "Generate model from datasource",
|
||||
"url": "The data source of database,like \"postgres://root:password@127.0.0.1:5432/database?sslmode=disable\"",
|
||||
@@ -204,6 +206,7 @@
|
||||
"short": "Generate mongo model",
|
||||
"type": "Specified model type name",
|
||||
"cache": "Generate code with cache [optional]",
|
||||
"prefix": "Generate code with cache prefix [optional]",
|
||||
"easy": "Generate code with auto generated CollectionName for easy declare [optional]",
|
||||
"dir": "{{.goctl.model.dir}}",
|
||||
"style": "{{.global.style}}",
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// BuildVersion is the version of goctl.
|
||||
const BuildVersion = "1.8.3-beta"
|
||||
const BuildVersion = "1.8.4"
|
||||
|
||||
var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-bata": 2, "beta": 3, "released": 4, "": 5}
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/rpc/execx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/console"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/ctx"
|
||||
@@ -37,7 +37,7 @@ func editMod(version string, verbose bool) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if !stringx.Contains(latest, version) {
|
||||
if !slices.Contains(latest, version) {
|
||||
return fmt.Errorf("release version %q is not found", version)
|
||||
}
|
||||
|
||||
|
||||
@@ -58,9 +58,11 @@ func init() {
|
||||
pgDatasourceCmdFlags.StringVar(&command.VarStringBranch, "branch")
|
||||
pgCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns,
|
||||
"ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})
|
||||
pgCmd.PersistentFlags().StringVarPWithDefaultValue(&command.VarStringCachePrefix, "prefix", "p", "cache")
|
||||
|
||||
mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t")
|
||||
mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c")
|
||||
mongoCmdFlags.StringVarP(&mongo.VarStringPrefix, "prefix", "p")
|
||||
mongoCmdFlags.BoolVarP(&mongo.VarBoolEasy, "easy", "e")
|
||||
mongoCmdFlags.StringVarP(&mongo.VarStringDir, "dir", "d")
|
||||
mongoCmdFlags.StringVar(&mongo.VarStringStyle, "style")
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
type Context struct {
|
||||
Types []string
|
||||
Cache bool
|
||||
Prefix string
|
||||
Easy bool
|
||||
Output string
|
||||
Cfg *config.Config
|
||||
@@ -60,6 +61,7 @@ func generateModel(ctx *Context) error {
|
||||
"Type": stringx.From(t).Title(),
|
||||
"lowerType": stringx.From(t).Untitle(),
|
||||
"Cache": ctx.Cache,
|
||||
"Prefix": ctx.Prefix,
|
||||
"version": version.BuildVersion,
|
||||
}, output, true); err != nil {
|
||||
return err
|
||||
|
||||
@@ -19,6 +19,8 @@ var (
|
||||
VarStringDir string
|
||||
// VarBoolCache describes whether cache is enabled.
|
||||
VarBoolCache bool
|
||||
// VarStringPrefix string describes the prefix for the cache key.
|
||||
VarStringPrefix string
|
||||
// VarBoolEasy describes whether to generate Collection Name in the code for easy declare.
|
||||
VarBoolEasy bool
|
||||
// VarStringStyle describes the style.
|
||||
@@ -35,6 +37,7 @@ var (
|
||||
func Action(_ *cobra.Command, _ []string) error {
|
||||
tp := VarStringSliceType
|
||||
c := VarBoolCache
|
||||
p := VarStringPrefix
|
||||
easy := VarBoolEasy
|
||||
o := strings.TrimSpace(VarStringDir)
|
||||
s := VarStringStyle
|
||||
@@ -74,6 +77,7 @@ func Action(_ *cobra.Command, _ []string) error {
|
||||
return generate.Do(&generate.Context{
|
||||
Types: tp,
|
||||
Cache: c,
|
||||
Prefix: p,
|
||||
Easy: easy,
|
||||
Output: a,
|
||||
Cfg: cfg,
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
{{if .Cache}}var prefix{{.Type}}CacheKey = "cache:{{.lowerType}}:"{{end}}
|
||||
{{if .Cache}}var prefix{{.Type}}CacheKey = "{{if .Prefix}}{{.Prefix}}:{{end}}cache:{{.lowerType}}:"{{end}}
|
||||
|
||||
type {{.lowerType}}Model interface{
|
||||
Insert(ctx context.Context,data *{{.Type}}) error
|
||||
|
||||
@@ -225,7 +225,20 @@ func PostgreSqlDataSource(_ *cobra.Command, _ []string) error {
|
||||
}
|
||||
ignoreColumns := mergeColumns(VarStringSliceIgnoreColumns)
|
||||
|
||||
return fromPostgreSqlDataSource(url, patterns, dir, schema, cfg, cache, idea, VarBoolStrict, ignoreColumns)
|
||||
arg := pgDataSourceArg{
|
||||
url: url,
|
||||
dir: dir,
|
||||
tablePat: patterns,
|
||||
schema: schema,
|
||||
cfg: cfg,
|
||||
cache: cache,
|
||||
idea: idea,
|
||||
strict: VarBoolStrict,
|
||||
ignoreColumns: ignoreColumns,
|
||||
prefix: VarStringCachePrefix,
|
||||
}
|
||||
|
||||
return fromPostgreSqlDataSource(arg)
|
||||
}
|
||||
|
||||
type ddlArg struct {
|
||||
@@ -339,32 +352,43 @@ func fromMysqlDataSource(arg dataSourceArg) error {
|
||||
return generator.StartFromInformationSchema(matchTables, arg.cache, arg.strict)
|
||||
}
|
||||
|
||||
func fromPostgreSqlDataSource(url string, pattern pattern, dir, schema string, cfg *config.Config, cache, idea, strict bool, ignoreColumns []string) error {
|
||||
log := console.NewConsole(idea)
|
||||
if len(url) == 0 {
|
||||
type pgDataSourceArg struct {
|
||||
url, dir string
|
||||
tablePat pattern
|
||||
schema string
|
||||
cfg *config.Config
|
||||
cache, idea bool
|
||||
strict bool
|
||||
ignoreColumns []string
|
||||
prefix string
|
||||
}
|
||||
|
||||
func fromPostgreSqlDataSource(arg pgDataSourceArg) error {
|
||||
log := console.NewConsole(arg.idea)
|
||||
if len(arg.url) == 0 {
|
||||
log.Error("%v", "expected data source of postgresql, but nothing found")
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(pattern) == 0 {
|
||||
if len(arg.tablePat) == 0 {
|
||||
log.Error("%v", "expected table or table globbing patterns, but nothing found")
|
||||
return nil
|
||||
}
|
||||
db := postgres.New(url)
|
||||
db := postgres.New(arg.url)
|
||||
im := model.NewPostgreSqlModel(db)
|
||||
|
||||
tables, err := im.GetAllTables(schema)
|
||||
tables, err := im.GetAllTables(arg.schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
matchTables := make(map[string]*model.Table)
|
||||
for _, item := range tables {
|
||||
if !pattern.Match(item) {
|
||||
if !arg.tablePat.Match(item) {
|
||||
continue
|
||||
}
|
||||
|
||||
columnData, err := im.FindColumns(schema, item)
|
||||
columnData, err := im.FindColumns(arg.schema, item)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -381,10 +405,11 @@ func fromPostgreSqlDataSource(url string, pattern pattern, dir, schema string, c
|
||||
return errors.New("no tables matched")
|
||||
}
|
||||
|
||||
generator, err := gen.NewDefaultGenerator("", dir, cfg, gen.WithConsoleOption(log), gen.WithPostgreSql(), gen.WithIgnoreColumns(ignoreColumns))
|
||||
generator, err := gen.NewDefaultGenerator(arg.prefix, arg.dir, arg.cfg, gen.WithConsoleOption(log),
|
||||
gen.WithPostgreSql(), gen.WithIgnoreColumns(arg.ignoreColumns))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return generator.StartFromInformationSchema(matchTables, cache, strict)
|
||||
return generator.StartFromInformationSchema(matchTables, arg.cache, arg.strict)
|
||||
}
|
||||
|
||||
@@ -465,6 +465,20 @@ func (p *Parser) parsePathItem() []token.Token {
|
||||
if !p.advanceIfPeekTokenIs(token.IDENT) {
|
||||
return nil
|
||||
}
|
||||
list = append(list, p.curTok)
|
||||
} else if p.peekTokenIs(token.DOT) {
|
||||
// Allow dot (.) in path segments for file extensions like .php, .html, etc.
|
||||
if !p.nextToken() {
|
||||
return nil
|
||||
}
|
||||
|
||||
list = append(list, p.curTok)
|
||||
|
||||
// After a dot, we expect an identifier (e.g., .php, .html)
|
||||
if !p.advanceIfPeekTokenIs(token.IDENT) {
|
||||
return nil
|
||||
}
|
||||
|
||||
list = append(list, p.curTok)
|
||||
} else {
|
||||
if p.peekTokenIs(token.LPAREN, token.Returns, token.AT_DOC, token.AT_HANDLER, token.SEMICOLON, token.RBRACE) {
|
||||
@@ -1342,7 +1356,7 @@ func (p *Parser) parseKVExpression() *ast.KVExpr {
|
||||
expr.Colon = p.curTokenNode()
|
||||
|
||||
// token STRING
|
||||
if !p.advanceIfPeekTokenIs(token.STRING, token.RAW_STRING) {
|
||||
if !p.advanceIfPeekTokenIs(token.STRING, token.RAW_STRING, token.IDENT) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -130,6 +130,8 @@ func TestParser_Parse_infoStmt(t *testing.T) {
|
||||
"author": `"type author here"`,
|
||||
"email": `"type email here"`,
|
||||
"version": `"type version here"`,
|
||||
"enable": `true`,
|
||||
"disable": `false`,
|
||||
}
|
||||
p := New("foo.api", infoTestAPI)
|
||||
result := p.Parse()
|
||||
@@ -760,11 +762,6 @@ func TestParser_Parse_service(t *testing.T) {
|
||||
}),
|
||||
},
|
||||
Request: &ast.BodyStmt{
|
||||
LParen: ast.NewTokenNode(token.Token{Type: token.LPAREN, Text: "("}),
|
||||
RParen: ast.NewTokenNode(token.Token{Type: token.RPAREN, Text: ")"}),
|
||||
},
|
||||
Returns: ast.NewTokenNode(token.Token{Type: token.IDENT, Text: "returns"}),
|
||||
Response: &ast.BodyStmt{
|
||||
LParen: ast.NewTokenNode(token.Token{Type: token.LPAREN, Text: "("}),
|
||||
Body: &ast.BodyExpr{
|
||||
Value: ast.NewTokenNode(token.Token{Type: token.IDENT, Text: "Foo"}),
|
||||
@@ -1031,6 +1028,7 @@ func TestParser_Parse_pathItem(t *testing.T) {
|
||||
{input: "1", expected: "1"},
|
||||
{input: "11", expected: "11"},
|
||||
}
|
||||
|
||||
for _, v := range testData {
|
||||
p := New("foo.api", v.input)
|
||||
ok := p.nextToken()
|
||||
@@ -1066,6 +1064,38 @@ func TestParser_Parse_pathItem(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestParser_Parse_pathItem_WithDot(t *testing.T) {
|
||||
t.Run("valid with dots", func(t *testing.T) {
|
||||
var testData = []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{input: "file.php", expected: "file.php"},
|
||||
{input: "api_jsonrpc.php", expected: "api_jsonrpc.php"},
|
||||
{input: "index.html", expected: "index.html"},
|
||||
{input: "data.json", expected: "data.json"},
|
||||
{input: "style.css", expected: "style.css"},
|
||||
{input: "script.js", expected: "script.js"},
|
||||
{input: "document.pdf", expected: "document.pdf"},
|
||||
{input: "image.png", expected: "image.png"},
|
||||
{input: "api.v1", expected: "api.v1"},
|
||||
{input: "resource.with.multiple.dots", expected: "resource.with.multiple.dots"},
|
||||
}
|
||||
|
||||
for _, v := range testData {
|
||||
p := New("foo.api", v.input)
|
||||
ok := p.nextToken()
|
||||
assert.True(t, ok)
|
||||
tokens := p.parsePathItem()
|
||||
var expected []string
|
||||
for _, tok := range tokens {
|
||||
expected = append(expected, tok.Text)
|
||||
}
|
||||
assert.Equal(t, strings.Join(expected, ""), v.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParser_Parse_parseTypeStmt(t *testing.T) {
|
||||
assertEqual := func(t *testing.T, expected, actual ast.Stmt) {
|
||||
if expected == nil {
|
||||
@@ -1399,6 +1429,7 @@ func TestParser_Parse_parseTypeStmt(t *testing.T) {
|
||||
assertEqual(t, val.expected, one)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("parseTypeGroupStmt", func(t *testing.T) {
|
||||
var testData = []struct {
|
||||
input string
|
||||
@@ -1472,6 +1503,7 @@ func TestParser_Parse_parseTypeStmt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, val := range testData {
|
||||
p := New("test.api", val.input)
|
||||
result := p.Parse()
|
||||
|
||||
@@ -4,4 +4,6 @@ info(
|
||||
author: "type author here"
|
||||
email: "type email here"
|
||||
version: "type version here"
|
||||
enable: true
|
||||
disable: false
|
||||
)
|
||||
@@ -10,6 +10,8 @@ info ( // info stmt
|
||||
author: "type author here"
|
||||
email: "type email here"
|
||||
version: "type version here"
|
||||
enable: true
|
||||
disable: false
|
||||
)
|
||||
|
||||
type AliasInt int
|
||||
|
||||
@@ -192,3 +192,131 @@ func TestUntitle(t *testing.T) {
|
||||
assert.Equal(t, c.want, ret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsAny(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
runes []rune
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "runes is empty",
|
||||
args: args{
|
||||
s: "test",
|
||||
runes: []rune{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "s is empty and runes is not empty",
|
||||
args: args{
|
||||
s: "",
|
||||
runes: []rune{'a', 'b', 'c'},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "s contains runes",
|
||||
args: args{
|
||||
s: "hello",
|
||||
runes: []rune{'e', 'f'},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "s does not contain runes",
|
||||
args: args{
|
||||
s: "hello",
|
||||
runes: []rune{'x', 'y'},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "s and runes both have one matching character",
|
||||
args: args{
|
||||
s: "a",
|
||||
runes: []rune{'a'},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "s and runes both have one non-matching character",
|
||||
args: args{
|
||||
s: "a",
|
||||
runes: []rune{'b'},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equalf(t, tt.want, ContainsAny(tt.args.s, tt.args.runes...), "ContainsAny(%v, %v)", tt.args.s, tt.args.runes)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsWhiteSpace(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "contains space",
|
||||
args: args{s: "hello world"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains newline",
|
||||
args: args{s: "hello\nworld"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains tab",
|
||||
args: args{s: "hello\tworld"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains form feed",
|
||||
args: args{s: "hello\fworld"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains vertical tab",
|
||||
args: args{s: "hello\vworld"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no whitespace",
|
||||
args: args{s: "helloworld"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
args: args{s: ""},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "only whitespace",
|
||||
args: args{s: " \t\n\f\v"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains non-standard whitespace",
|
||||
args: args{s: "hello\u00A0world"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equalf(t, tt.want, ContainsWhiteSpace(tt.args.s), "ContainsWhiteSpace(%v)", tt.args.s)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
@@ -47,7 +46,7 @@ func TestDirectBuilder_Build(t *testing.T) {
|
||||
}, cc, resolver.BuildOptions{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
size := mathx.MinInt(test, subsetSize)
|
||||
size := min(test, subsetSize)
|
||||
assert.Equal(t, size, len(cc.state.Addresses))
|
||||
m := make(map[string]lang.PlaceholderType)
|
||||
for _, each := range cc.state.Addresses {
|
||||
|
||||
Reference in New Issue
Block a user