Compare commits

...

57 Commits

Author SHA1 Message Date
Kevin Wan
c9ff6a10d3 feat: support serverless in rest (#5001)
Signed-off-by: kevin <wanjunfeng@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-13 00:00:52 +08:00
Kevin Wan
a71e56de52 fix: context key error in sql read write mode (#5000) 2025-07-12 06:58:08 +08:00
Kevin Wan
bae8d4f4c8 chore: refactoring sql read write mode (#4990)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-11 01:05:55 +08:00
zhoushuguang
8c6266f338 sql read write support (#4976)
Co-authored-by: light.zhou <light.zhou@bkyo.io>
2025-07-09 16:04:56 +00:00
Kevin Wan
95d5b81f44 chore: optimize pr 4979 (#4988)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-09 23:55:24 +08:00
geekeryy
bca7bbc142 fix: correct duration type comparison in environment variable processing (#4979) 2025-07-09 15:22:27 +00:00
Kevin Wan
df9a52664b fix issue #4986 2025-07-08 13:58:48 +00:00
Kevin Wan
937cf0db96 Update readme-cn.md (#4983) 2025-07-04 11:02:49 +08:00
Kevin Wan
75cebb65f8 fix: timeout 0s not working (#4932)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-01 17:01:24 +08:00
dependabot[bot]
410f56e73a chore(deps): bump github.com/redis/go-redis/v9 from 9.10.0 to 9.11.0 (#4969)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-25 18:35:01 +08:00
dependabot[bot]
017909a3ab chore(deps): bump github.com/emicklei/proto from 1.14.1 to 1.14.2 in /tools/goctl (#4961)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-19 15:40:18 +08:00
kesonan
0d31e6c375 (goctl): fix #4943 (#4953) 2025-06-14 15:36:30 +00:00
Kevin Wan
0ba86b1849 chore: add more tests (#4949) 2025-06-13 22:10:08 +08:00
wanwu
4cacc4d9d3 fix: the time.Duration type panics due to numerical values (#4944)
Co-authored-by: sam.yang <sam.yang@yijinin.com>
2025-06-12 15:11:07 +00:00
Eric
a99c14da4a fix: typo of the logic of CpuThreshold in comments (#4942)
Co-authored-by: zhouyy <zhouyy@ickey.cn>
2025-06-12 08:28:44 +00:00
Kevin Wan
985582264a chore: fix warnings (#4940) 2025-06-12 00:04:29 +08:00
Kevin Wan
8364e341e1 chore: update go-zero dep (#4933)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-06-09 18:08:20 +08:00
Kevin Wan
0f2b589d4d Revert "fix: api group set timeout: 0s not working." (#4931) 2025-06-08 23:14:38 +08:00
spectatorMrZ
19fec36d24 fix: api group set timeout: 0s not working. (#4785) 2025-06-08 14:50:21 +00:00
Kevin Wan
f037bf344d chore: add more tests (#4930) 2025-06-08 22:08:04 +08:00
MarkJoyMa
d99cf35b07 Feat/continue profiling (#4867)
Co-authored-by: aiden.ma <Aiden.ma@yijinin.com>
Co-authored-by: aiden.ma <aiden.ma@bkyo.io>
2025-06-07 21:12:31 +08:00
Kevin Wan
f459f1b5ff chore: update goctl version (#4929) 2025-06-07 21:01:35 +08:00
Haiwei Zhang
0140fd417b feat(goctl): generate mongo model with cache prefix (#4907) 2025-06-07 12:54:33 +00:00
jaron
7969e0ca38 fix(goctl): Fix getting swagger consume types (#4903) 2025-06-07 12:46:34 +00:00
Kevin Wan
91c885b5b0 chore: add more unit tests for mcp (#4928) 2025-06-07 20:41:57 +08:00
MarkJoyMa
d4cccca387 Fix the problem that mcp request id is not of int type (#4914) 2025-06-07 10:37:18 +08:00
dependabot[bot]
4b2095ed03 chore(deps): bump github.com/redis/go-redis/v9 from 9.9.0 to 9.10.0 (#4926) 2025-06-07 10:07:26 +08:00
dependabot[bot]
1229eeb2d2 chore(deps): bump go.mongodb.org/mongo-driver from 1.17.3 to 1.17.4 (#4924) 2025-06-06 19:45:26 +08:00
dependabot[bot]
9142b146c5 chore(deps): bump github.com/alicebob/miniredis/v2 from 2.34.0 to 2.35.0 (#4919)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-06 10:09:15 +08:00
Kevin Wan
8a1b2d5aed chore: fix typo (#4920) 2025-06-05 22:51:22 +08:00
Leon cap
da5d39e6ca fix: correct spelling of 'cancellation' in timeout handler comment (#4916)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2025-06-05 22:42:53 +08:00
Leon cap
68c5a17c67 fix: correct spelling of 'underlying' in Header method comment (#4918) 2025-06-05 10:36:21 +00:00
Leon cap
b53f9f5f2d fix: correct spelling of 'TimeoutHandler' in timeout handler comment (#4917) 2025-06-04 15:48:37 +00:00
dependabot[bot]
36d57626b6 chore(deps): bump github.com/redis/go-redis/v9 from 9.8.0 to 9.9.0 (#4905)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-05-28 11:32:57 +08:00
Kevin Wan
4e36ba832f Update readme.md (#4897) 2025-05-25 22:25:56 +08:00
Kevin Wan
a44954a771 fix: don't set read/write timeout if timeout middleware disabled (#4895)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-05-25 15:07:58 +08:00
kesonan
f3edd4b880 goctl: v1.8.4-beta (#4890) 2025-05-25 05:36:56 +00:00
Kevin Wan
2de3e397ff chore: revert go version to 1.21 (#4893) 2025-05-24 18:17:49 +08:00
Qiu shao
a435eb56f2 perf(hash): optimize Md5Hex encoding performance (#4891) 2025-05-24 08:41:21 +00:00
Kevin Wan
d80761c147 chore: refactor coding style (#4887) 2025-05-22 23:29:40 +08:00
me-cs
e7bd0d8b60 update:To standardize the time format, use the go standard library's own (#4875) 2025-05-22 15:26:53 +00:00
me-cs
b109b3ef4c update:use builtin cmp func (#4879) 2025-05-22 15:19:13 +00:00
Kevin Wan
e3c371ac89 chore: refactor 2025-05-20 12:59:35 +00:00
燕归来
15eb6f4f6d test(hash): modify TestConsistentHashTransferOnFailure to more reasonable test transfer ratio (#4874) 2025-05-20 12:51:50 +00:00
me-cs
4d3681b71c Optimize slicing operations (#4877) 2025-05-20 11:36:02 +00:00
dependabot[bot]
a682bda0bb chore(deps): bump github.com/jackc/pgx/v5 from 5.7.4 to 5.7.5 (#4871)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-05-20 10:19:06 +08:00
kesonan
45b27ad93a goctl: 1.8.4-beta (#4869) 2025-05-19 15:30:50 +00:00
Kevin Wan
292a8302a1 chore: optimize mcp (#4866) 2025-05-17 12:28:06 +08:00
kesonan
91ab1f6d2b goctl features of 1.8.4-alpha (#4849) 2025-05-15 13:59:48 +00:00
Kevin Wan
5048c350ae chore: fix test failure in profilecenter_test.go 2025-05-15 13:31:53 +00:00
Kevin Wan
94edc32f3e chore: optimize profile center and remove tablewriter dependency 2025-05-15 13:22:27 +00:00
Kevin Wan
ec989b2e2a chore: for backward compatibility (#4852)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-05-11 20:19:00 +08:00
me-cs
82fe802e81 update:Use the official sync.OnceFunc (#4840) 2025-05-11 12:08:43 +00:00
me-cs
072d68f897 update:Use the official slice operate func (#4841) 2025-05-11 11:48:54 +00:00
Kevin Wan
2e91ba5811 chore: refactor rest file server (#4851)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-05-11 12:44:43 +08:00
shaouai
5564c43197 feat: serve files using embed.FS (#4847) 2025-05-10 15:43:13 +00:00
Kevin Wan
e55158b0f7 chore: update deps in goctl (#4830) 2025-05-04 16:18:02 +08:00
99 changed files with 13402 additions and 578 deletions

View File

@@ -8,16 +8,12 @@ import (
"sync" "sync"
"time" "time"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/proc" "github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/core/stringx"
) )
const ( const numHistoryReasons = 5
numHistoryReasons = 5
timeFormat = "15:04:05"
)
// ErrServiceUnavailable is returned when the Breaker state is open. // ErrServiceUnavailable is returned when the Breaker state is open.
var ErrServiceUnavailable = errors.New("circuit breaker is open") var ErrServiceUnavailable = errors.New("circuit breaker is open")
@@ -262,9 +258,9 @@ type errorWindow struct {
func (ew *errorWindow) add(reason string) { func (ew *errorWindow) add(reason string) {
ew.lock.Lock() ew.lock.Lock()
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason) ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(time.TimeOnly), reason)
ew.index = (ew.index + 1) % numHistoryReasons ew.index = (ew.index + 1) % numHistoryReasons
ew.count = mathx.MinInt(ew.count+1, numHistoryReasons) ew.count = min(ew.count+1, numHistoryReasons)
ew.lock.Unlock() ew.lock.Unlock()
} }

View File

@@ -423,7 +423,7 @@ func TestRegistry_Monitor(t *testing.T) {
GetRegistry().clusters = map[string]*cluster{ GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): { getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{ watchers: map[watchKey]*watchValue{
watchKey{ {
key: "foo", key: "foo",
exactMatch: true, exactMatch: true,
}: { }: {
@@ -449,7 +449,7 @@ func TestRegistry_Unmonitor(t *testing.T) {
GetRegistry().clusters = map[string]*cluster{ GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): { getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{ watchers: map[watchKey]*watchValue{
watchKey{ {
key: "foo", key: "foo",
exactMatch: true, exactMatch: true,
}: { }: {

View File

@@ -86,21 +86,16 @@ func TestConsistentHashIncrementalTransfer(t *testing.T) {
func TestConsistentHashTransferOnFailure(t *testing.T) { func TestConsistentHashTransferOnFailure(t *testing.T) {
index := 41 index := 41
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index) ratioNotExists := getTransferRatioOnFailure(t, index)
var transferred int assert.True(t, ratioNotExists == 0, fmt.Sprintf("%d: %f", index, ratioNotExists))
for k, v := range newKeys { index = 13
if v != keys[k] { ratio := getTransferRatioOnFailure(t, index)
transferred++ assert.True(t, ratio < 2.5/keySize, fmt.Sprintf("%d: %f", index, ratio))
}
}
ratio := float32(transferred) / float32(requestSize)
assert.True(t, ratio < 2.5/float32(keySize), fmt.Sprintf("%d: %f", index, ratio))
} }
func TestConsistentHashLeastTransferOnFailure(t *testing.T) { func TestConsistentHashLeastTransferOnFailure(t *testing.T) {
prefix := "localhost:" prefix := "localhost:"
index := 41 index := 13
keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index) keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index)
for k, v := range keys { for k, v := range keys {
newV := newKeys[k] newV := newKeys[k]
@@ -164,6 +159,17 @@ func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[i
return keys, newKeys return keys, newKeys
} }
func getTransferRatioOnFailure(t *testing.T, index int) float32 {
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
var transferred int
for k, v := range newKeys {
if v != keys[k] {
transferred++
}
}
return float32(transferred) / float32(requestSize)
}
type mockNode struct { type mockNode struct {
addr string addr string
id int id int

View File

@@ -2,7 +2,7 @@ package hash
import ( import (
"crypto/md5" "crypto/md5"
"fmt" "encoding/hex"
"github.com/spaolacci/murmur3" "github.com/spaolacci/murmur3"
) )
@@ -20,6 +20,7 @@ func Md5(data []byte) []byte {
} }
// Md5Hex returns the md5 hex string of data. // Md5Hex returns the md5 hex string of data.
// This function is optimized for better performance than fmt.Sprintf.
func Md5Hex(data []byte) string { func Md5Hex(data []byte) string {
return fmt.Sprintf("%x", Md5(data)) return hex.EncodeToString(Md5(data))
} }

View File

@@ -7,12 +7,11 @@ import (
) )
var ( var (
fieldsContextKey contextKey
globalFields atomic.Value globalFields atomic.Value
globalFieldsLock sync.Mutex globalFieldsLock sync.Mutex
) )
type contextKey struct{} type fieldsKey struct{}
// AddGlobalFields adds global fields. // AddGlobalFields adds global fields.
func AddGlobalFields(fields ...LogField) { func AddGlobalFields(fields ...LogField) {
@@ -29,16 +28,16 @@ func AddGlobalFields(fields ...LogField) {
// ContextWithFields returns a new context with the given fields. // ContextWithFields returns a new context with the given fields.
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context { func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
if val := ctx.Value(fieldsContextKey); val != nil { if val := ctx.Value(fieldsKey{}); val != nil {
if arr, ok := val.([]LogField); ok { if arr, ok := val.([]LogField); ok {
allFields := make([]LogField, 0, len(arr)+len(fields)) allFields := make([]LogField, 0, len(arr)+len(fields))
allFields = append(allFields, arr...) allFields = append(allFields, arr...)
allFields = append(allFields, fields...) allFields = append(allFields, fields...)
return context.WithValue(ctx, fieldsContextKey, allFields) return context.WithValue(ctx, fieldsKey{}, allFields)
} }
} }
return context.WithValue(ctx, fieldsContextKey, fields) return context.WithValue(ctx, fieldsKey{}, fields)
} }
// WithFields returns a new logger with the given fields. // WithFields returns a new logger with the given fields.

View File

@@ -34,7 +34,7 @@ func TestAddGlobalFields(t *testing.T) {
func TestContextWithFields(t *testing.T) { func TestContextWithFields(t *testing.T) {
ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2)) ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2))
vals := ctx.Value(fieldsContextKey) vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals) assert.NotNil(t, vals)
fields, ok := vals.([]LogField) fields, ok := vals.([]LogField)
assert.True(t, ok) assert.True(t, ok)
@@ -43,7 +43,7 @@ func TestContextWithFields(t *testing.T) {
func TestWithFields(t *testing.T) { func TestWithFields(t *testing.T) {
ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2)) ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2))
vals := ctx.Value(fieldsContextKey) vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals) assert.NotNil(t, vals)
fields, ok := vals.([]LogField) fields, ok := vals.([]LogField)
assert.True(t, ok) assert.True(t, ok)
@@ -55,7 +55,7 @@ func TestWithFieldsAppend(t *testing.T) {
ctx := context.WithValue(context.Background(), dummyKey, "dummy") ctx := context.WithValue(context.Background(), dummyKey, "dummy")
ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2)) ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2))
ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4)) ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4))
vals := ctx.Value(fieldsContextKey) vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals) assert.NotNil(t, vals)
fields, ok := vals.([]LogField) fields, ok := vals.([]LogField)
assert.True(t, ok) assert.True(t, ok)
@@ -80,8 +80,8 @@ func TestWithFieldsAppendCopy(t *testing.T) {
ctxa := ContextWithFields(ctx, af) ctxa := ContextWithFields(ctx, af)
ctxb := ContextWithFields(ctx, bf) ctxb := ContextWithFields(ctx, bf)
assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count]) assert.EqualValues(t, af, ctxa.Value(fieldsKey{}).([]LogField)[count])
assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count]) assert.EqualValues(t, bf, ctxb.Value(fieldsKey{}).([]LogField)[count])
} }
func BenchmarkAtomicValue(b *testing.B) { func BenchmarkAtomicValue(b *testing.B) {

View File

@@ -224,7 +224,7 @@ func (l *richLogger) buildFields(fields ...LogField) []LogField {
fields = append(fields, Field(spanKey, spanID)) fields = append(fields, Field(spanKey, spanID))
} }
val := l.ctx.Value(fieldsContextKey) val := l.ctx.Value(fieldsKey{})
if val != nil { if val != nil {
if arr, ok := val.([]LogField); ok { if arr, ok := val.([]LogField); ok {
fields = append(fields, arr...) fields = append(fields, arr...)

View File

@@ -18,7 +18,6 @@ import (
) )
const ( const (
dateFormat = "2006-01-02"
hoursPerDay = 24 hoursPerDay = 24
bufferSize = 100 bufferSize = 100
defaultDirMode = 0o755 defaultDirMode = 0o755
@@ -116,7 +115,7 @@ func (r *DailyRotateRule) OutdatedFiles() []string {
} }
var buf strings.Builder var buf strings.Builder
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(dateFormat) boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(time.DateOnly)
buf.WriteString(r.filename) buf.WriteString(r.filename)
buf.WriteString(r.delimiter) buf.WriteString(r.delimiter)
buf.WriteString(boundary) buf.WriteString(boundary)
@@ -425,7 +424,7 @@ func compressLogFile(file string) {
} }
func getNowDate() string { func getNowDate() string {
return time.Now().Format(dateFormat) return time.Now().Format(time.DateOnly)
} }
func getNowDateInRFC3339Format() string { func getNowDateInRFC3339Format() string {

View File

@@ -52,7 +52,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
}) })
t.Run("temp files", func(t *testing.T) { t.Run("temp files", func(t *testing.T) {
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat) boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary) f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err) assert.NoError(t, err)
_ = f1.Close() _ = f1.Close()
@@ -73,7 +73,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
func TestDailyRotateRuleShallRotate(t *testing.T) { func TestDailyRotateRuleShallRotate(t *testing.T) {
var rule DailyRotateRule var rule DailyRotateRule
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(dateFormat) rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(time.DateOnly)
assert.True(t, rule.ShallRotate(0)) assert.True(t, rule.ShallRotate(0))
} }
@@ -117,12 +117,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
}) })
t.Run("temp files", func(t *testing.T) { t.Run("temp files", func(t *testing.T) {
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat) boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary) f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err) assert.NoError(t, err)
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary) f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err) assert.NoError(t, err)
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat) boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1) f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
assert.NoError(t, err) assert.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
@@ -144,12 +144,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
}) })
t.Run("no backups", func(t *testing.T) { t.Run("no backups", func(t *testing.T) {
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat) boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary) f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err) assert.NoError(t, err)
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary) f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err) assert.NoError(t, err)
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat) boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1) f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
assert.NoError(t, err) assert.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
@@ -319,7 +319,7 @@ func TestRotateLoggerWrite(t *testing.T) {
} }
// the following write calls cannot be changed to Write, because of DATA RACE. // the following write calls cannot be changed to Write, because of DATA RACE.
logger.write([]byte(`foo`)) logger.write([]byte(`foo`))
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat) rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
logger.write([]byte(`bar`)) logger.write([]byte(`bar`))
logger.Close() logger.Close()
logger.write([]byte(`baz`)) logger.write([]byte(`baz`))
@@ -447,7 +447,7 @@ func TestRotateLoggerWithSizeLimitRotateRuleWrite(t *testing.T) {
} }
// the following write calls cannot be changed to Write, because of DATA RACE. // the following write calls cannot be changed to Write, because of DATA RACE.
logger.write([]byte(`foo`)) logger.write([]byte(`foo`))
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat) rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
logger.write([]byte(`bar`)) logger.write([]byte(`bar`))
logger.Close() logger.Close()
logger.write([]byte(`baz`)) logger.write([]byte(`baz`))

View File

@@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -15,7 +16,6 @@ import (
"github.com/zeromicro/go-zero/core/jsonx" "github.com/zeromicro/go-zero/core/jsonx"
"github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/proc" "github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stringx"
) )
const ( const (
@@ -30,7 +30,9 @@ var (
errValueNotSettable = errors.New("value is not settable") errValueNotSettable = errors.New("value is not settable")
errValueNotStruct = errors.New("value type is not struct") errValueNotStruct = errors.New("value type is not struct")
keyUnmarshaler = NewUnmarshaler(defaultKeyName) keyUnmarshaler = NewUnmarshaler(defaultKeyName)
boolType = reflect.TypeOf(false)
durationType = reflect.TypeOf(time.Duration(0)) durationType = reflect.TypeOf(time.Duration(0))
stringType = reflect.TypeOf("")
cacheKeys = make(map[string][]string) cacheKeys = make(map[string][]string)
cacheKeysLock sync.Mutex cacheKeysLock sync.Mutex
defaultCache = make(map[string]any) defaultCache = make(map[string]any)
@@ -622,9 +624,19 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
return u.fillSliceFromString(fieldType, value, mapValue, fullName) return u.fillSliceFromString(fieldType, value, mapValue, fullName)
case valueKind == reflect.String && derefedFieldType == durationType: case valueKind == reflect.String && derefedFieldType == durationType:
return fillDurationValue(fieldType, value, mapValue.(string)) v, err := convertToString(mapValue, fullName)
if err != nil {
return err
}
return fillDurationValue(fieldType, value, v)
case valueKind == reflect.String && typeKind == reflect.Struct && u.implementsUnmarshaler(fieldType): case valueKind == reflect.String && typeKind == reflect.Struct && u.implementsUnmarshaler(fieldType):
return u.fillUnmarshalerStruct(fieldType, value, mapValue.(string)) v, err := convertToString(mapValue, fullName)
if err != nil {
return err
}
return u.fillUnmarshalerStruct(fieldType, value, v)
default: default:
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName) return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
} }
@@ -755,24 +767,24 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
return err return err
} }
fieldKind := fieldType.Kind() derefType := Deref(fieldType)
switch fieldKind { switch derefType {
case reflect.Bool: case boolType:
val, err := strconv.ParseBool(envVal) val, err := strconv.ParseBool(envVal)
if err != nil { if err != nil {
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err) return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
} }
value.SetBool(val) SetValue(fieldType, value, reflect.ValueOf(val))
return nil return nil
case durationType.Kind(): case durationType:
if err := fillDurationValue(fieldType, value, envVal); err != nil { if err := fillDurationValue(fieldType, value, envVal); err != nil {
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err) return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
} }
return nil return nil
case reflect.String: case stringType:
value.SetString(envVal) SetValue(fieldType, value, reflect.ValueOf(envVal))
return nil return nil
default: default:
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, json.Number(envVal), opts, fullName) return u.processFieldPrimitiveWithJSONNumber(fieldType, value, json.Number(envVal), opts, fullName)
@@ -894,7 +906,7 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ
valueKind.String()) valueKind.String())
} }
if !stringx.Contains(options, checkValue) { if !slices.Contains(options, checkValue) {
return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`, return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`,
mapValue, key, options) mapValue, key, options)
} }

View File

@@ -203,6 +203,20 @@ func TestUnmarshalDuration(t *testing.T) {
} }
} }
func TestUnmarshalDurationUnexpectedError(t *testing.T) {
type inner struct {
Duration time.Duration `key:"duration"`
}
content := "{\"duration\": 1}"
var m = map[string]any{}
err := jsonx.Unmarshal([]byte(content), &m)
assert.NoError(t, err)
var in inner
err = UnmarshalKey(m, &in)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expect string")
}
func TestUnmarshalDurationDefault(t *testing.T) { func TestUnmarshalDurationDefault(t *testing.T) {
type inner struct { type inner struct {
Int int `key:"int"` Int int `key:"int"`
@@ -4665,6 +4679,23 @@ func TestUnmarshal_EnvInt(t *testing.T) {
} }
} }
func TestUnmarshal_EnvInt64(t *testing.T) {
type Value struct {
Age int64 `key:"age,env=TEST_NAME_INT64"`
}
const (
envName = "TEST_NAME_INT64"
envVal = "88"
)
t.Setenv(envName, envVal)
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, int64(88), v.Age)
}
}
func TestUnmarshal_EnvIntOverwrite(t *testing.T) { func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
type Value struct { type Value struct {
Age int `key:"age,env=TEST_NAME_INT"` Age int `key:"age,env=TEST_NAME_INT"`
@@ -4770,20 +4801,33 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) {
} }
func TestUnmarshal_EnvDuration(t *testing.T) { func TestUnmarshal_EnvDuration(t *testing.T) {
type Value struct {
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
}
const ( const (
envName = "TEST_NAME_DURATION" envName = "TEST_NAME_DURATION"
envVal = "1s" envVal = "1s"
) )
t.Setenv(envName, envVal) t.Setenv(envName, envVal)
var v Value t.Run("valid duration", func(t *testing.T) {
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) { type Value struct {
assert.Equal(t, time.Second, v.Duration) Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
} }
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, time.Second, v.Duration)
}
})
t.Run("ptr of duration", func(t *testing.T) {
type Value struct {
Duration *time.Duration `key:"duration,env=TEST_NAME_DURATION"`
}
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, time.Second, *v.Duration)
}
})
} }
func TestUnmarshal_EnvDurationBadValue(t *testing.T) { func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
@@ -5995,6 +6039,16 @@ func TestUnmarshal_Unmarshaler(t *testing.T) {
}, &v)) }, &v))
assert.Nil(t, v.Foo) assert.Nil(t, v.Foo)
}) })
t.Run("json.Number", func(t *testing.T) {
v := struct {
Foo *mockUnmarshaler `json:"name"`
}{}
m := map[string]any{
"name": json.Number("123"),
}
assert.Error(t, UnmarshalJsonMap(m, &v))
})
} }
func TestParseJsonStringValue(t *testing.T) { func TestParseJsonStringValue(t *testing.T) {

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"math" "math"
"reflect" "reflect"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -91,6 +92,15 @@ func ValidatePtr(v reflect.Value) error {
return nil return nil
} }
func convertToString(val any, fullName string) (string, error) {
v, ok := val.(string)
if !ok {
return "", fmt.Errorf("expect string for field %s, but got type %T", fullName, val)
}
return v, nil
}
func convertTypeFromString(kind reflect.Kind, str string) (any, error) { func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
switch kind { switch kind {
case reflect.Bool: case reflect.Bool:
@@ -634,11 +644,11 @@ func validateValueInOptions(val any, options []string) error {
if len(options) > 0 { if len(options) > 0 {
switch v := val.(type) { switch v := val.(type) {
case string: case string:
if !stringx.Contains(options, v) { if !slices.Contains(options, v) {
return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options) return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options)
} }
default: default:
if !stringx.Contains(options, Repr(v)) { if !slices.Contains(options, Repr(v)) {
return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options) return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
} }
} }

View File

@@ -1,19 +1,13 @@
package mathx package mathx
// MaxInt returns the larger one of a and b. // MaxInt returns the larger one of a and b.
// Deprecated: use builtin max instead.
func MaxInt(a, b int) int { func MaxInt(a, b int) int {
if a > b { return max(a, b)
return a
}
return b
} }
// MinInt returns the smaller one of a and b. // MinInt returns the smaller one of a and b.
// Deprecated: use builtin min instead.
func MinInt(a, b int) int { func MinInt(a, b int) int {
if a < b { return min(a, b)
return a
}
return b
} }

View File

@@ -1,13 +1,12 @@
package prof package prof
import ( import (
"bytes" "fmt"
"strconv" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/olekukonko/tablewriter"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/threading" "github.com/zeromicro/go-zero/core/threading"
) )
@@ -28,46 +27,15 @@ type (
const flushInterval = 5 * time.Minute const flushInterval = 5 * time.Minute
var ( var pc = &profileCenter{
pc = &profileCenter{ slots: make(map[string]*profileSlot),
slots: make(map[string]*profileSlot),
}
once sync.Once
)
func report(name string, duration time.Duration) {
updated := func() bool {
pc.lock.RLock()
defer pc.lock.RUnlock()
slot, ok := pc.slots[name]
if ok {
atomic.AddInt64(&slot.lifecount, 1)
atomic.AddInt64(&slot.lastcount, 1)
atomic.AddInt64(&slot.lifecycle, int64(duration))
atomic.AddInt64(&slot.lastcycle, int64(duration))
}
return ok
}()
if !updated {
func() {
pc.lock.Lock()
defer pc.lock.Unlock()
pc.slots[name] = &profileSlot{
lifecount: 1,
lastcount: 1,
lifecycle: int64(duration),
lastcycle: int64(duration),
}
}()
}
once.Do(flushRepeatly)
} }
func flushRepeatly() { func init() {
flushRepeatedly()
}
func flushRepeatedly() {
threading.GoSafe(func() { threading.GoSafe(func() {
for { for {
time.Sleep(flushInterval) time.Sleep(flushInterval)
@@ -76,42 +44,64 @@ func flushRepeatly() {
}) })
} }
func report(name string, duration time.Duration) {
slot := loadOrStoreSlot(name, duration)
atomic.AddInt64(&slot.lifecount, 1)
atomic.AddInt64(&slot.lastcount, 1)
atomic.AddInt64(&slot.lifecycle, int64(duration))
atomic.AddInt64(&slot.lastcycle, int64(duration))
}
func loadOrStoreSlot(name string, duration time.Duration) *profileSlot {
pc.lock.RLock()
slot, ok := pc.slots[name]
pc.lock.RUnlock()
if ok {
return slot
}
pc.lock.Lock()
defer pc.lock.Unlock()
// double-check
if slot, ok = pc.slots[name]; ok {
return slot
}
slot = &profileSlot{}
pc.slots[name] = slot
return slot
}
func generateReport() string { func generateReport() string {
var buffer bytes.Buffer var builder strings.Builder
buffer.WriteString("Profiling report\n") builder.WriteString("Profiling report\n")
var data [][]string builder.WriteString("QUEUE,LIFECOUNT,LIFECYCLE,LASTCOUNT,LASTCYCLE\n")
calcFn := func(total, count int64) string { calcFn := func(total, count int64) string {
if count == 0 { if count == 0 {
return "-" return "-"
} }
return (time.Duration(total) / time.Duration(count)).String() return (time.Duration(total) / time.Duration(count)).String()
} }
func() { pc.lock.Lock()
pc.lock.Lock() for key, slot := range pc.slots {
defer pc.lock.Unlock() builder.WriteString(fmt.Sprintf("%s,%d,%s,%d,%s\n",
key,
slot.lifecount,
calcFn(slot.lifecycle, slot.lifecount),
slot.lastcount,
calcFn(slot.lastcycle, slot.lastcount),
))
for key, slot := range pc.slots { // reset last cycle stats
data = append(data, []string{ atomic.StoreInt64(&slot.lastcount, 0)
key, atomic.StoreInt64(&slot.lastcycle, 0)
strconv.FormatInt(slot.lifecount, 10), }
calcFn(slot.lifecycle, slot.lifecount), pc.lock.Unlock()
strconv.FormatInt(slot.lastcount, 10),
calcFn(slot.lastcycle, slot.lastcount),
})
// reset the data for last cycle return builder.String()
slot.lastcount = 0
slot.lastcycle = 0
}
}()
table := tablewriter.NewWriter(&buffer)
table.SetHeader([]string{"QUEUE", "LIFECOUNT", "LIFECYCLE", "LASTCOUNT", "LASTCYCLE"})
table.SetBorder(false)
table.AppendBulk(data)
table.Render()
return buffer.String()
} }

View File

@@ -8,7 +8,6 @@ import (
) )
func TestReport(t *testing.T) { func TestReport(t *testing.T) {
once.Do(func() {})
assert.NotContains(t, generateReport(), "foo") assert.NotContains(t, generateReport(), "foo")
report("foo", time.Second) report("foo", time.Second)
assert.Contains(t, generateReport(), "foo") assert.Contains(t, generateReport(), "foo")

View File

@@ -8,6 +8,7 @@ import (
"github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/trace" "github.com/zeromicro/go-zero/core/trace"
"github.com/zeromicro/go-zero/internal/devserver" "github.com/zeromicro/go-zero/internal/devserver"
"github.com/zeromicro/go-zero/internal/profiling"
) )
const ( const (
@@ -38,6 +39,8 @@ type (
Telemetry trace.Config `json:",optional"` Telemetry trace.Config `json:",optional"`
DevServer DevServerConfig `json:",optional"` DevServer DevServerConfig `json:",optional"`
Shutdown proc.ShutdownConf `json:",optional"` Shutdown proc.ShutdownConf `json:",optional"`
// Profiling is the configuration for continuous profiling.
Profiling profiling.Config `json:",optional"`
} }
) )
@@ -70,7 +73,9 @@ func (sc ServiceConf) SetUp() error {
if len(sc.MetricsUrl) > 0 { if len(sc.MetricsUrl) > 0 {
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl)) stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
} }
devserver.StartAgent(sc.DevServer) devserver.StartAgent(sc.DevServer)
profiling.Start(sc.Profiling)
return nil return nil
} }

View File

@@ -1,9 +1,10 @@
package service package service
import ( import (
"sync"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/proc" "github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/threading" "github.com/zeromicro/go-zero/core/threading"
) )
@@ -35,7 +36,7 @@ type (
// NewServiceGroup returns a ServiceGroup. // NewServiceGroup returns a ServiceGroup.
func NewServiceGroup() *ServiceGroup { func NewServiceGroup() *ServiceGroup {
sg := new(ServiceGroup) sg := new(ServiceGroup)
sg.stopOnce = syncx.Once(sg.doStop) sg.stopOnce = sync.OnceFunc(sg.doStop)
return sg return sg
} }

View File

@@ -19,7 +19,6 @@ import (
const ( const (
clusterNameKey = "CLUSTER_NAME" clusterNameKey = "CLUSTER_NAME"
testEnv = "test.v" testEnv = "test.v"
timeFormat = "2006-01-02 15:04:05"
) )
var ( var (
@@ -45,7 +44,7 @@ func Report(msg string) {
if fn != nil { if fn != nil {
reported := lessExecutor.DoOrDiscard(func() { reported := lessExecutor.DoOrDiscard(func() {
var builder strings.Builder var builder strings.Builder
builder.WriteString(fmt.Sprintln(time.Now().Format(timeFormat))) builder.WriteString(fmt.Sprintln(time.Now().Format(time.DateTime)))
if len(clusterName) > 0 { if len(clusterName) > 0 {
builder.WriteString(fmt.Sprintf("cluster: %s\n", clusterName)) builder.WriteString(fmt.Sprintf("cluster: %s\n", clusterName))
} }

View File

@@ -0,0 +1,29 @@
package sqlx
import "errors"
var (
errEmptyDatasource = errors.New("empty datasource")
errEmptyDriverName = errors.New("empty driver name")
)
// SqlConf defines the configuration for sqlx.
type SqlConf struct {
DataSource string
DriverName string `json:",default=mysql"`
Replicas []string `json:",optional"`
Policy string `json:",default=round-robin,options=round-robin|random"`
}
// Validate validates the SqlxConf.
func (sc SqlConf) Validate() error {
if len(sc.DataSource) == 0 {
return errEmptyDatasource
}
if len(sc.DriverName) == 0 {
return errEmptyDriverName
}
return nil
}

View File

@@ -0,0 +1,29 @@
package sqlx
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf"
)
func TestValidate(t *testing.T) {
text := []byte(`DataSource: primary:password@tcp(127.0.0.1:3306)/primary_db
`)
var sc SqlConf
err := conf.LoadFromYamlBytes(text, &sc)
assert.Nil(t, err)
assert.Equal(t, "mysql", sc.DriverName)
assert.Equal(t, policyRoundRobin, sc.Policy)
assert.Nil(t, sc.Validate())
sc = SqlConf{}
assert.Equal(t, errEmptyDatasource, sc.Validate())
sc.DataSource = "primary:password@tcp(127.0.0.1:3306)/primary_db"
assert.Equal(t, errEmptyDriverName, sc.Validate())
sc.DriverName = "mysql"
assert.Nil(t, sc.Validate())
}

View File

@@ -0,0 +1,65 @@
package sqlx
import "context"
const (
// policyRoundRobin round-robin policy for selecting replicas.
policyRoundRobin = "round-robin"
// policyRandom random policy for selecting replicas.
policyRandom = "random"
// readPrimaryMode indicates that the operation is a read,
// but should be performed on the primary database instance.
//
// This mode is used in scenarios where data freshness and consistency are critical,
// such as immediately after writes or where replication lag may cause stale reads.
readPrimaryMode readWriteMode = "read-primary"
// readReplicaMode indicates that the operation is a read from replicas.
// This is suitable for scenarios where eventual consistency is acceptable,
// and the goal is to offload traffic from the primary and improve read scalability.
readReplicaMode readWriteMode = "read-replica"
// writeMode indicates that the operation is a write operation (to primary).
writeMode readWriteMode = "write"
// notSpecifiedMode indicates that the read/write mode is not specified.
notSpecifiedMode readWriteMode = ""
)
type readWriteModeKey struct{}
// WithReadPrimary sets the context to read-primary mode.
func WithReadPrimary(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey{}, readPrimaryMode)
}
// WithReadReplica sets the context to read-replica mode.
func WithReadReplica(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey{}, readReplicaMode)
}
// WithWrite sets the context to write mode, indicating that the operation is a write operation.
func WithWrite(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey{}, writeMode)
}
type readWriteMode string
func (m readWriteMode) isValid() bool {
return m == readPrimaryMode || m == readReplicaMode || m == writeMode
}
func getReadWriteMode(ctx context.Context) readWriteMode {
if mode := ctx.Value(readWriteModeKey{}); mode != nil {
if v, ok := mode.(readWriteMode); ok && v.isValid() {
return v
}
}
return notSpecifiedMode
}
func usePrimary(ctx context.Context) bool {
return getReadWriteMode(ctx) != readReplicaMode
}

View File

@@ -0,0 +1,142 @@
package sqlx
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsValid(t *testing.T) {
testCases := []struct {
name string
mode readWriteMode
expected bool
}{
{
name: "valid read-primary mode",
mode: readPrimaryMode,
expected: true,
},
{
name: "valid read-replica mode",
mode: readReplicaMode,
expected: true,
},
{
name: "valid write mode",
mode: writeMode,
expected: true,
},
{
name: "not specified mode (empty)",
mode: notSpecifiedMode,
expected: false,
},
{
name: "invalid custom string",
mode: readWriteMode("delete"),
expected: false,
},
{
name: "case sensitive check",
mode: readWriteMode("READ"),
expected: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := tc.mode.isValid()
assert.Equal(t, tc.expected, actual)
})
}
}
func TestWithReadMode(t *testing.T) {
ctx := context.Background()
readPrimaryCtx := WithReadPrimary(ctx)
val := readPrimaryCtx.Value(readWriteModeKey{})
assert.Equal(t, readPrimaryMode, val)
readReplicaCtx := WithReadReplica(ctx)
val = readReplicaCtx.Value(readWriteModeKey{})
assert.Equal(t, readReplicaMode, val)
}
func TestWithWriteMode(t *testing.T) {
ctx := context.Background()
writeCtx := WithWrite(ctx)
val := writeCtx.Value(readWriteModeKey{})
assert.Equal(t, writeMode, val)
}
func TestGetReadWriteMode(t *testing.T) {
t.Run("valid read-primary mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx))
})
t.Run("valid read-replica mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
assert.Equal(t, readReplicaMode, getReadWriteMode(ctx))
})
t.Run("valid write mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
assert.Equal(t, writeMode, getReadWriteMode(ctx))
})
t.Run("invalid mode value (wrong type)", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, "not-a-mode")
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
})
t.Run("invalid mode value (wrong value)", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("delete"))
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
})
t.Run("no mode set", func(t *testing.T) {
ctx := context.Background()
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
})
}
func TestUsePrimary(t *testing.T) {
t.Run("context with read-replica mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
assert.False(t, usePrimary(ctx))
})
t.Run("context with read-primary mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
assert.True(t, usePrimary(ctx))
})
t.Run("context with write mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
assert.True(t, usePrimary(ctx))
})
t.Run("context with invalid mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("invalid"))
assert.True(t, usePrimary(ctx))
})
t.Run("context with no mode set", func(t *testing.T) {
ctx := context.Background()
assert.True(t, usePrimary(ctx))
})
}
func TestWithModeTwice(t *testing.T) {
ctx := context.Background()
ctx = WithReadPrimary(ctx)
writeCtx := WithWrite(ctx)
val := writeCtx.Value(readWriteModeKey{})
assert.Equal(t, writeMode, val)
}

View File

@@ -4,6 +4,9 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"math/rand"
"sync/atomic"
"github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/errorx" "github.com/zeromicro/go-zero/core/errorx"
@@ -52,9 +55,10 @@ type (
beginTx beginnable beginTx beginnable
brk breaker.Breaker brk breaker.Breaker
accept breaker.Acceptable accept breaker.Acceptable
index uint32
} }
connProvider func() (*sql.DB, error) connProvider func(ctx context.Context) (*sql.DB, error)
sessionConn interface { sessionConn interface {
Exec(query string, args ...any) (sql.Result, error) Exec(query string, args ...any) (sql.Result, error)
@@ -64,10 +68,41 @@ type (
} }
) )
// MustNewConn returns a SqlConn with the given SqlConf.
func MustNewConn(c SqlConf, opts ...SqlOption) SqlConn {
conn, err := NewConn(c, opts...)
if err != nil {
logx.Must(err)
}
return conn
}
// NewConn returns a SqlConn with the given SqlConf.
func NewConn(c SqlConf, opts ...SqlOption) (SqlConn, error) {
if err := c.Validate(); err != nil {
return nil, err
}
conn := &commonSqlConn{
onError: func(ctx context.Context, err error) {
logInstanceError(ctx, c.DataSource, err)
},
beginTx: begin,
brk: breaker.NewBreaker(),
}
for _, opt := range opts {
opt(conn)
}
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
return conn, nil
}
// NewSqlConn returns a SqlConn with given driver name and datasource. // NewSqlConn returns a SqlConn with given driver name and datasource.
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn { func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{ conn := &commonSqlConn{
connProv: func() (*sql.DB, error) { connProv: func(context.Context) (*sql.DB, error) {
return getSqlConn(driverName, datasource) return getSqlConn(driverName, datasource)
}, },
onError: func(ctx context.Context, err error) { onError: func(ctx context.Context, err error) {
@@ -87,7 +122,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
// Use it with caution; it's provided for other ORM to interact with. // Use it with caution; it's provided for other ORM to interact with.
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn { func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{ conn := &commonSqlConn{
connProv: func() (*sql.DB, error) { connProv: func(ctx context.Context) (*sql.DB, error) {
return db, nil return db, nil
}, },
onError: func(ctx context.Context, err error) { onError: func(ctx context.Context, err error) {
@@ -123,7 +158,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
err = db.brk.DoWithAcceptableCtx(ctx, func() error { err = db.brk.DoWithAcceptableCtx(ctx, func() error {
var conn *sql.DB var conn *sql.DB
conn, err = db.connProv() conn, err = db.connProv(ctx)
if err != nil { if err != nil {
db.onError(ctx, err) db.onError(ctx, err)
return err return err
@@ -151,7 +186,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
err = db.brk.DoWithAcceptableCtx(ctx, func() error { err = db.brk.DoWithAcceptableCtx(ctx, func() error {
var conn *sql.DB var conn *sql.DB
conn, err = db.connProv() conn, err = db.connProv(ctx)
if err != nil { if err != nil {
db.onError(ctx, err) db.onError(ctx, err)
return err return err
@@ -242,7 +277,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
} }
func (db *commonSqlConn) RawDB() (*sql.DB, error) { func (db *commonSqlConn) RawDB() (*sql.DB, error) {
return db.connProv() return db.connProv(context.Background())
} }
func (db *commonSqlConn) Transact(fn func(Session) error) error { func (db *commonSqlConn) Transact(fn func(Session) error) error {
@@ -288,7 +323,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
q string, args ...any) (err error) { q string, args ...any) (err error) {
var scanFailed bool var scanFailed bool
err = db.brk.DoWithAcceptableCtx(ctx, func() error { err = db.brk.DoWithAcceptableCtx(ctx, func() error {
conn, err := db.connProv() conn, err := db.connProv(ctx)
if err != nil { if err != nil {
db.onError(ctx, err) db.onError(ctx, err)
return err return err
@@ -311,6 +346,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
return return
} }
func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, replicas []string) connProvider {
return func(ctx context.Context) (*sql.DB, error) {
replicaCount := len(replicas)
if replicaCount == 0 || usePrimary(ctx) {
return getSqlConn(driverName, datasource)
}
var dsn string
if replicaCount == 1 {
dsn = replicas[0]
} else {
if len(policy) == 0 {
policy = policyRoundRobin
}
switch policy {
case policyRandom:
dsn = replicas[rand.Intn(replicaCount)]
case policyRoundRobin:
index := atomic.AddUint32(&sc.index, 1) - 1
dsn = replicas[index%uint32(replicaCount)]
default:
return nil, fmt.Errorf("unknown policy: %s", policy)
}
}
return getSqlConn(driverName, dsn)
}
}
// WithAcceptable returns a SqlOption that setting the acceptable function. // WithAcceptable returns a SqlOption that setting the acceptable function.
// acceptable is the func to check if the error can be accepted. // acceptable is the func to check if the error can be accepted.
func WithAcceptable(acceptable func(err error) bool) SqlOption { func WithAcceptable(acceptable func(err error) bool) SqlOption {

View File

@@ -1,6 +1,7 @@
package sqlx package sqlx
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"io" "io"
@@ -98,7 +99,7 @@ func TestSqlConn_RawDB(t *testing.T) {
func TestSqlConn_Errors(t *testing.T) { func TestSqlConn_Errors(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db) conn := NewSqlConnFromDB(db)
conn.(*commonSqlConn).connProv = func() (*sql.DB, error) { conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
return nil, errors.New("error") return nil, errors.New("error")
} }
_, err := conn.Prepare("any") _, err := conn.Prepare("any")
@@ -138,6 +139,148 @@ func TestSqlConn_Errors(t *testing.T) {
}) })
} }
func TestConfigSqlConn(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
mock.ExpectExec("any")
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf, withMysqlAcceptable())
_, err = conn.Exec("any", "value")
assert.NotNil(t, err)
_, err = conn.Prepare("any")
assert.NotNil(t, err)
var val string
assert.NotNil(t, conn.QueryRow(&val, "any"))
assert.NotNil(t, conn.QueryRowPartial(&val, "any"))
assert.NotNil(t, conn.QueryRows(&val, "any"))
assert.NotNil(t, conn.QueryRowsPartial(&val, "any"))
}
func TestConfigSqlConnStatement(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
mock.ExpectPrepare("any")
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectPrepare("any")
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(row)
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf, withMysqlAcceptable())
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
res, err := stmt.Exec()
assert.NoError(t, err)
lastInsertID, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), lastInsertID)
rowsAffected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), rowsAffected)
stmt, err = conn.Prepare("any")
assert.NoError(t, err)
var val string
err = stmt.QueryRow(&val)
assert.NoError(t, err)
assert.Equal(t, "bar", val)
mock.ExpectPrepare("any")
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
stmt, err = conn.Prepare("any")
assert.NoError(t, err)
var vals []string
assert.NoError(t, stmt.QueryRowsPartial(&vals))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
}
func TestConfigSqlConnQuery(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
t.Run("QueryRow", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var val string
assert.NoError(t, conn.QueryRow(&val, "any"))
assert.Equal(t, "bar", val)
})
t.Run("QueryRowPartial", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var val string
assert.NoError(t, conn.QueryRowPartial(&val, "any"))
assert.Equal(t, "bar", val)
})
t.Run("QueryRows", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var vals []string
assert.NoError(t, conn.QueryRows(&vals, "any"))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
})
t.Run("QueryRowsPartial", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var vals []string
assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
})
}
func TestConfigSqlConnErr(t *testing.T) {
t.Run("panic on empty config", func(t *testing.T) {
original := logx.ExitOnFatal.True()
logx.ExitOnFatal.Set(false)
defer logx.ExitOnFatal.Set(original)
assert.Panics(t, func() {
MustNewConn(SqlConf{})
})
})
t.Run("on error", func(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
return nil, errors.New("error")
}
_, err = conn.Prepare("any")
assert.Error(t, err)
})
}
func TestStatement(t *testing.T) { func TestStatement(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any").WillBeClosed() mock.ExpectPrepare("any").WillBeClosed()
@@ -303,6 +446,93 @@ func TestWithAcceptable(t *testing.T) {
assert.True(t, conn.accept(acceptableErr3)) assert.True(t, conn.accept(acceptableErr3))
} }
func TestProvider(t *testing.T) {
defer func() {
_ = connManager.Close()
}()
primaryDSN := "primary:password@tcp(127.0.0.1:3306)/primary_db"
replicasDSN := []string{
"replica_one:pwd@tcp(localhost:3306)/replica_one",
"replica_two:pwd@tcp(localhost:3306)/replica_two",
"replica_three:pwd@tcp(localhost:3306)/replica_three",
}
primaryDB, err := connManager.GetResource(primaryDSN, func() (io.Closer, error) { return sql.Open(mysqlDriverName, primaryDSN) })
assert.Nil(t, err)
assert.NotNil(t, primaryDB)
replicaOneDB, err := connManager.GetResource(replicasDSN[0], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[0]) })
assert.Nil(t, err)
assert.NotNil(t, replicaOneDB)
replicaTwoDB, err := connManager.GetResource(replicasDSN[1], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[1]) })
assert.Nil(t, err)
assert.NotNil(t, replicaTwoDB)
replicaThreeDB, err := connManager.GetResource(replicasDSN[2], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[2]) })
assert.Nil(t, err)
assert.NotNil(t, replicaThreeDB)
sc := &commonSqlConn{}
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, nil)
ctx := context.Background()
db, err := sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
ctx = WithWrite(ctx)
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
ctx = WithReadPrimary(ctx)
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
// no mode set, should return primary
ctx = context.Background()
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
ctx = WithReadReplica(ctx)
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]})
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, replicaOneDB, db)
// default policy is round-robin
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
replicas := []io.Closer{replicaOneDB, replicaTwoDB, replicaThreeDB}
for i := 0; i < len(replicasDSN); i++ {
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, replicas[i], db)
}
// random policy
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRandom, replicasDSN)
for i := 0; i < len(replicasDSN); i++ {
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Contains(t, replicas, db)
}
// unknown policy
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "unknown", replicasDSN)
_, err = sc.connProv(ctx)
assert.NotNil(t, err)
// empty policy transforms to round-robin
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "", replicasDSN)
for i := 0; i < len(replicasDSN); i++ {
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, replicas[i], db)
}
}
func buildConn() (mock sqlmock.Sqlmock, err error) { func buildConn() (mock sqlmock.Sqlmock, err error) {
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) { _, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB var db *sql.DB

View File

@@ -27,7 +27,7 @@ func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
return nil, err return nil, err
} }
if driverName != mysqlDriverName { if driverName == mysqlDriverName {
if cfg, e := mysql.ParseDSN(server); e != nil { if cfg, e := mysql.ParseDSN(server); e != nil {
// if cannot parse, don't collect the metrics // if cannot parse, don't collect the metrics
logx.Error(e) logx.Error(e)

View File

@@ -156,7 +156,7 @@ func begin(db *sql.DB) (trans, error) {
func transact(ctx context.Context, db *commonSqlConn, b beginnable, func transact(ctx context.Context, db *commonSqlConn, b beginnable,
fn func(context.Context, Session) error) (err error) { fn func(context.Context, Session) error) (err error) {
conn, err := db.connProv() conn, err := db.connProv(ctx)
if err != nil { if err != nil {
db.onError(ctx, err) db.onError(ctx, err)
return err return err

View File

@@ -117,7 +117,7 @@ func TestTxExceptions(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := &commonSqlConn{ conn := &commonSqlConn{
connProv: func() (*sql.DB, error) { connProv: func(ctx context.Context) (*sql.DB, error) {
return nil, errors.New("foo") return nil, errors.New("foo")
}, },
beginTx: begin, beginTx: begin,

View File

@@ -2,6 +2,7 @@ package stringx
import ( import (
"errors" "errors"
"slices"
"unicode" "unicode"
"github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/lang"
@@ -15,14 +16,9 @@ var (
) )
// Contains checks if str is in list. // Contains checks if str is in list.
// Deprecated: use slices.Contains instead.
func Contains(list []string, str string) bool { func Contains(list []string, str string) bool {
for _, each := range list { return slices.Contains(list, str)
if each == str {
return true
}
}
return false
} }
// Filter filters chars from s with given filter function. // Filter filters chars from s with given filter function.
@@ -123,11 +119,7 @@ func Remove(strings []string, strs ...string) []string {
// Reverse reverses s. // Reverse reverses s.
func Reverse(s string) string { func Reverse(s string) string {
runes := []rune(s) runes := []rune(s)
slices.Reverse(runes)
for from, to := 0, len(runes)-1; from < to; from, to = from+1, to-1 {
runes[from], runes[to] = runes[to], runes[from]
}
return string(runes) return string(runes)
} }

View File

@@ -7,6 +7,28 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestContainsString(t *testing.T) {
cases := []struct {
slice []string
value string
expect bool
}{
{[]string{"1"}, "1", true},
{[]string{"1"}, "2", false},
{[]string{"1", "2"}, "1", true},
{[]string{"1", "2"}, "3", false},
{nil, "3", false},
{nil, "", false},
}
for _, each := range cases {
t.Run(path.Join(each.slice...), func(t *testing.T) {
actual := Contains(each.slice, each.value)
assert.Equal(t, each.expect, actual)
})
}
}
func TestNotEmpty(t *testing.T) { func TestNotEmpty(t *testing.T) {
cases := []struct { cases := []struct {
args []string args []string
@@ -41,28 +63,6 @@ func TestNotEmpty(t *testing.T) {
} }
} }
func TestContainsString(t *testing.T) {
cases := []struct {
slice []string
value string
expect bool
}{
{[]string{"1"}, "1", true},
{[]string{"1"}, "2", false},
{[]string{"1", "2"}, "1", true},
{[]string{"1", "2"}, "3", false},
{nil, "3", false},
{nil, "", false},
}
for _, each := range cases {
t.Run(path.Join(each.slice...), func(t *testing.T) {
actual := Contains(each.slice, each.value)
assert.Equal(t, each.expect, actual)
})
}
}
func TestFilter(t *testing.T) { func TestFilter(t *testing.T) {
cases := []struct { cases := []struct {
input string input string

View File

@@ -3,9 +3,7 @@ package syncx
import "sync" import "sync"
// Once returns a func that guarantees fn can only called once. // Once returns a func that guarantees fn can only called once.
// Deprecated: use sync.OnceFunc instead.
func Once(fn func()) func() { func Once(fn func()) func() {
once := new(sync.Once) return sync.OnceFunc(fn)
return func() {
once.Do(fn)
}
} }

View File

@@ -1,10 +1,10 @@
package utils package utils
import ( import (
"cmp"
"strconv" "strconv"
"strings" "strings"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/core/stringx"
) )
@@ -39,7 +39,7 @@ func compare(v1, v2 string) int {
fields1, fields2 := strings.Split(v1, "."), strings.Split(v2, ".") fields1, fields2 := strings.Split(v1, "."), strings.Split(v2, ".")
ver1, ver2 := strsToInts(fields1), strsToInts(fields2) ver1, ver2 := strsToInts(fields1), strsToInts(fields2)
ver1len, ver2len := len(ver1), len(ver2) ver1len, ver2len := len(ver1), len(ver2)
shorter := mathx.MinInt(ver1len, ver2len) shorter := min(ver1len, ver2len)
for i := 0; i < shorter; i++ { for i := 0; i < shorter; i++ {
if ver1[i] == ver2[i] { if ver1[i] == ver2[i] {
@@ -50,14 +50,7 @@ func compare(v1, v2 string) int {
return 1 return 1
} }
} }
return cmp.Compare(ver1len, ver2len)
if ver1len < ver2len {
return -1
} else if ver1len == ver2len {
return 0
} else {
return 1
}
} }
func strsToInts(strs []string) []int64 { func strsToInts(strs []string) []int64 {

10
go.mod
View File

@@ -4,7 +4,7 @@ go 1.21
require ( require (
github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/alicebob/miniredis/v2 v2.34.0 github.com/alicebob/miniredis/v2 v2.35.0
github.com/fatih/color v1.18.0 github.com/fatih/color v1.18.0
github.com/fullstorydev/grpcurl v1.9.3 github.com/fullstorydev/grpcurl v1.9.3
github.com/go-sql-driver/mysql v1.9.0 github.com/go-sql-driver/mysql v1.9.0
@@ -12,17 +12,17 @@ require (
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/golang/protobuf v1.5.4 github.com/golang/protobuf v1.5.4
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/grafana/pyroscope-go v1.2.2
github.com/jackc/pgx/v5 v5.7.4 github.com/jackc/pgx/v5 v5.7.4
github.com/jhump/protoreflect v1.17.0 github.com/jhump/protoreflect v1.17.0
github.com/olekukonko/tablewriter v0.0.5
github.com/pelletier/go-toml/v2 v2.2.2 github.com/pelletier/go-toml/v2 v2.2.2
github.com/prometheus/client_golang v1.21.1 github.com/prometheus/client_golang v1.21.1
github.com/redis/go-redis/v9 v9.8.0 github.com/redis/go-redis/v9 v9.11.0
github.com/spaolacci/murmur3 v1.1.0 github.com/spaolacci/murmur3 v1.1.0
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
go.etcd.io/etcd/api/v3 v3.5.15 go.etcd.io/etcd/api/v3 v3.5.15
go.etcd.io/etcd/client/v3 v3.5.15 go.etcd.io/etcd/client/v3 v3.5.15
go.mongodb.org/mongo-driver v1.17.3 go.mongodb.org/mongo-driver v1.17.4
go.opentelemetry.io/otel v1.24.0 go.opentelemetry.io/otel v1.24.0
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 go.opentelemetry.io/otel/exporters/jaeger v1.17.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
@@ -50,7 +50,6 @@ require (
require ( require (
filippo.io/edwards25519 v1.1.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/bufbuild/protocompile v0.14.1 // indirect github.com/bufbuild/protocompile v0.14.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect
@@ -73,6 +72,7 @@ require (
github.com/google/gnostic-models v0.6.8 // indirect github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-cmp v0.6.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect github.com/google/gofuzz v1.2.0 // indirect
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect

21
go.sum
View File

@@ -2,10 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
@@ -82,6 +80,10 @@ github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJY
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
@@ -121,7 +123,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -135,8 +136,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4=
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4= github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o= github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg= github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg=
@@ -159,8 +158,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI= github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs=
github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -203,8 +202,8 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU= go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4= go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU= go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
go.mongodb.org/mongo-driver v1.17.3 h1:TQyXhnsWfWtgAhMtOgtYHMTkZIfBTpMTsMnd9ZBeHxQ= go.mongodb.org/mongo-driver v1.17.4 h1:jUorfmVzljjr0FLzYQsGP8cgN/qzzxlY9Vh0C9KFXVw=
go.mongodb.org/mongo-driver v1.17.3/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ= go.mongodb.org/mongo-driver v1.17.4/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4= go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=

View File

@@ -0,0 +1,263 @@
package profiling
import (
"runtime"
"sync"
"time"
"github.com/grafana/pyroscope-go"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/threading"
)
const (
defaultCheckInterval = time.Second * 10
defaultProfilingDuration = time.Minute * 2
defaultUploadRate = time.Second * 15
)
type (
Config struct {
// Name is the name of the application.
Name string `json:",optional,inherit"`
// ServerAddr is the address of the profiling server.
ServerAddr string
// AuthUser is the username for basic authentication.
AuthUser string `json:",optional"`
// AuthPassword is the password for basic authentication.
AuthPassword string `json:",optional"`
// UploadRate is the duration for which profiling data is uploaded.
UploadRate time.Duration `json:",default=15s"`
// CheckInterval is the interval to check if profiling should start.
CheckInterval time.Duration `json:",default=10s"`
// ProfilingDuration is the duration for which profiling data is collected.
ProfilingDuration time.Duration `json:",default=2m"`
// CpuThreshold the collection is allowed only when the current service cpu > CpuThreshold
CpuThreshold int64 `json:",default=700,range=[0:1000)"`
// ProfileType is the type of profiling to be performed.
ProfileType ProfileType
}
ProfileType struct {
// Logger is a flag to enable or disable logging.
Logger bool `json:",default=false"`
// CPU is a flag to disable CPU profiling.
CPU bool `json:",default=true"`
// Goroutines is a flag to disable goroutine profiling.
Goroutines bool `json:",default=true"`
// Memory is a flag to disable memory profiling.
Memory bool `json:",default=true"`
// Mutex is a flag to disable mutex profiling.
Mutex bool `json:",default=false"`
// Block is a flag to disable block profiling.
Block bool `json:",default=false"`
}
profiler interface {
Start() error
Stop() error
}
pyroscopeProfiler struct {
c Config
profiler *pyroscope.Profiler
}
)
var (
once sync.Once
newProfiler = func(c Config) profiler {
return newPyroscopeProfiler(c)
}
)
// Start initializes the pyroscope profiler with the given configuration.
func Start(c Config) {
// check if the profiling is enabled
if len(c.ServerAddr) == 0 {
return
}
// set default values for the configuration
if c.ProfilingDuration <= 0 {
c.ProfilingDuration = defaultProfilingDuration
}
// set default values for the configuration
if c.CheckInterval <= 0 {
c.CheckInterval = defaultCheckInterval
}
if c.UploadRate <= 0 {
c.UploadRate = defaultUploadRate
}
once.Do(func() {
logx.Info("continuous profiling started")
threading.GoSafe(func() {
startPyroscope(c, proc.Done())
})
})
}
// startPyroscope starts the pyroscope profiler with the given configuration.
func startPyroscope(c Config, done <-chan struct{}) {
var (
pr profiler
err error
latestProfilingTime time.Time
intervalTicker = time.NewTicker(c.CheckInterval)
profilingTicker = time.NewTicker(c.ProfilingDuration)
)
defer profilingTicker.Stop()
defer intervalTicker.Stop()
for {
select {
case <-intervalTicker.C:
// Check if the machine is overloaded and if the profiler is not running
if pr == nil && isCpuOverloaded(c) {
pr = newProfiler(c)
if err := pr.Start(); err != nil {
logx.Errorf("failed to start profiler: %v", err)
continue
}
// record the latest profiling time
latestProfilingTime = time.Now()
logx.Infof("pyroscope profiler started.")
}
case <-profilingTicker.C:
// check if the profiling duration has passed
if !time.Now().After(latestProfilingTime.Add(c.ProfilingDuration)) {
continue
}
// check if the profiler is already running, if so, skip
if pr != nil {
if err = pr.Stop(); err != nil {
logx.Errorf("failed to stop profiler: %v", err)
}
logx.Infof("pyroscope profiler stopped.")
pr = nil
}
case <-done:
logx.Infof("continuous profiling stopped.")
return
}
}
}
// genPyroscopeConf generates the pyroscope configuration based on the given config.
func genPyroscopeConf(c Config) pyroscope.Config {
pConf := pyroscope.Config{
UploadRate: c.UploadRate,
ApplicationName: c.Name,
BasicAuthUser: c.AuthUser, // http basic auth user
BasicAuthPassword: c.AuthPassword, // http basic auth password
ServerAddress: c.ServerAddr,
Logger: nil,
HTTPHeaders: map[string]string{},
// you can provide static tags via a map:
Tags: map[string]string{
"name": c.Name,
},
}
if c.ProfileType.Logger {
pConf.Logger = logx.WithCallerSkip(0)
}
if c.ProfileType.CPU {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileCPU)
}
if c.ProfileType.Goroutines {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileGoroutines)
}
if c.ProfileType.Memory {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileAllocObjects, pyroscope.ProfileAllocSpace,
pyroscope.ProfileInuseObjects, pyroscope.ProfileInuseSpace)
}
if c.ProfileType.Mutex {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileMutexCount, pyroscope.ProfileMutexDuration)
}
if c.ProfileType.Block {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileBlockCount, pyroscope.ProfileBlockDuration)
}
logx.Infof("applicationName: %s", pConf.ApplicationName)
return pConf
}
// isCpuOverloaded checks the machine performance based on the given configuration.
func isCpuOverloaded(c Config) bool {
currentValue := stat.CpuUsage()
if currentValue >= c.CpuThreshold {
logx.Infof("continuous profiling cpu overload, cpu: %d", currentValue)
return true
}
return false
}
func newPyroscopeProfiler(c Config) profiler {
return &pyroscopeProfiler{
c: c,
}
}
func (p *pyroscopeProfiler) Start() error {
pConf := genPyroscopeConf(p.c)
// set mutex and block profile rate
setFraction(p.c)
prof, err := pyroscope.Start(pConf)
if err != nil {
resetFraction(p.c)
return err
}
p.profiler = prof
return nil
}
func (p *pyroscopeProfiler) Stop() error {
if p.profiler == nil {
return nil
}
if err := p.profiler.Stop(); err != nil {
return err
}
resetFraction(p.c)
p.profiler = nil
return nil
}
func setFraction(c Config) {
// These 2 lines are only required if you're using mutex or block profiling
if c.ProfileType.Mutex {
runtime.SetMutexProfileFraction(10) // 10/seconds
}
if c.ProfileType.Block {
runtime.SetBlockProfileRate(1000 * 1000) // 1/millisecond
}
}
func resetFraction(c Config) {
// These 2 lines are only required if you're using mutex or block profiling
if c.ProfileType.Mutex {
runtime.SetMutexProfileFraction(0)
}
if c.ProfileType.Block {
runtime.SetBlockProfileRate(0)
}
}

View File

@@ -0,0 +1,177 @@
package profiling
import (
"sync"
"testing"
"time"
"github.com/grafana/pyroscope-go"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/syncx"
)
func TestStart(t *testing.T) {
t.Run("profiling", func(t *testing.T) {
var c Config
assert.NoError(t, conf.FillDefault(&c))
c.Name = "test"
p := newProfiler(c)
assert.NotNil(t, p)
assert.NoError(t, p.Start())
assert.NoError(t, p.Stop())
})
t.Run("invalid config", func(t *testing.T) {
mp := &mockProfiler{}
newProfiler = func(c Config) profiler {
return mp
}
Start(Config{})
Start(Config{
ServerAddr: "localhost:4040",
})
})
t.Run("test start profiler", func(t *testing.T) {
mp := &mockProfiler{}
newProfiler = func(c Config) profiler {
return mp
}
c := Config{
Name: "test",
ServerAddr: "localhost:4040",
CheckInterval: time.Millisecond,
ProfilingDuration: time.Millisecond * 10,
CpuThreshold: 0,
}
var done = make(chan struct{})
go startPyroscope(c, done)
time.Sleep(time.Millisecond * 50)
close(done)
assert.True(t, mp.started.True())
assert.True(t, mp.stopped.True())
})
t.Run("test start profiler with cpu overloaded", func(t *testing.T) {
mp := &mockProfiler{}
newProfiler = func(c Config) profiler {
return mp
}
c := Config{
Name: "test",
ServerAddr: "localhost:4040",
CheckInterval: time.Millisecond,
ProfilingDuration: time.Millisecond * 10,
CpuThreshold: 900,
}
var done = make(chan struct{})
go startPyroscope(c, done)
time.Sleep(time.Millisecond * 50)
close(done)
assert.False(t, mp.started.True())
})
t.Run("start/stop err", func(t *testing.T) {
mp := &mockProfiler{
err: assert.AnError,
}
newProfiler = func(c Config) profiler {
return mp
}
c := Config{
Name: "test",
ServerAddr: "localhost:4040",
CheckInterval: time.Millisecond,
ProfilingDuration: time.Millisecond * 10,
CpuThreshold: 0,
}
var done = make(chan struct{})
go startPyroscope(c, done)
time.Sleep(time.Millisecond * 50)
close(done)
assert.False(t, mp.started.True())
assert.False(t, mp.stopped.True())
})
}
func TestGenPyroscopeConf(t *testing.T) {
c := Config{
Name: "",
ServerAddr: "localhost:4040",
AuthUser: "user",
AuthPassword: "password",
ProfileType: ProfileType{
Logger: true,
CPU: true,
Goroutines: true,
Memory: true,
Mutex: true,
Block: true,
},
}
pyroscopeConf := genPyroscopeConf(c)
assert.Equal(t, c.ServerAddr, pyroscopeConf.ServerAddress)
assert.Equal(t, c.AuthUser, pyroscopeConf.BasicAuthUser)
assert.Equal(t, c.AuthPassword, pyroscopeConf.BasicAuthPassword)
assert.Equal(t, c.Name, pyroscopeConf.ApplicationName)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileCPU)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileGoroutines)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocObjects)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocSpace)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseObjects)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseSpace)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexCount)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexDuration)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockCount)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockDuration)
setFraction(c)
resetFraction(c)
newPyroscopeProfiler(c)
}
func TestNewPyroscopeProfiler(t *testing.T) {
p := newPyroscopeProfiler(Config{})
assert.Error(t, p.Start())
assert.NoError(t, p.Stop())
}
type mockProfiler struct {
mutex sync.Mutex
started syncx.AtomicBool
stopped syncx.AtomicBool
err error
}
func (m *mockProfiler) Start() error {
m.mutex.Lock()
if m.err == nil {
m.started.Set(true)
}
m.mutex.Unlock()
return m.err
}
func (m *mockProfiler) Stop() error {
m.mutex.Lock()
if m.err == nil {
m.stopped.Set(true)
}
m.mutex.Unlock()
return m.err
}

View File

@@ -173,17 +173,20 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
w.WriteHeader(http.StatusAccepted)
// For notification methods (no ID), we don't send a response // For notification methods (no ID), we don't send a response
isNotification := req.ID == 0 isNotification, err := req.isNotification()
if err != nil {
http.Error(w, "Invalid request.ID", http.StatusBadRequest)
}
w.WriteHeader(http.StatusAccepted)
// Special handling for initialization sequence // Special handling for initialization sequence
// Always allow initialize and notifications/initialized regardless of client state // Always allow initialize and notifications/initialized regardless of client state
if req.Method == methodInitialize { if req.Method == methodInitialize {
logx.Infof("Processing initialize request with ID: %d", req.ID) logx.Infof("Processing initialize request with ID: %v", req.ID)
s.processInitialize(r.Context(), client, req) s.processInitialize(r.Context(), client, req)
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID) logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID)
return return
} else if req.Method == methodNotificationsInitialized { } else if req.Method == methodNotificationsInitialized {
// Handle initialized notification // Handle initialized notification
@@ -206,41 +209,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
// Process normal requests only after initialization // Process normal requests only after initialization
switch req.Method { switch req.Method {
case methodToolsCall: case methodToolsCall:
logx.Infof("Received tools call request with ID: %d", req.ID) logx.Infof("Received tools call request with ID: %v", req.ID)
s.processToolCall(r.Context(), client, req) s.processToolCall(r.Context(), client, req)
logx.Infof("Sent tools call response for ID: %d", req.ID) logx.Infof("Sent tools call response for ID: %v", req.ID)
case methodToolsList: case methodToolsList:
logx.Infof("Processing tools/list request with ID: %d", req.ID) logx.Infof("Processing tools/list request with ID: %v", req.ID)
s.processListTools(r.Context(), client, req) s.processListTools(r.Context(), client, req)
logx.Infof("Sent tools/list response for ID: %d", req.ID) logx.Infof("Sent tools/list response for ID: %v", req.ID)
case methodPromptsList: case methodPromptsList:
logx.Infof("Processing prompts/list request with ID: %d", req.ID) logx.Infof("Processing prompts/list request with ID: %v", req.ID)
s.processListPrompts(r.Context(), client, req) s.processListPrompts(r.Context(), client, req)
logx.Infof("Sent prompts/list response for ID: %d", req.ID) logx.Infof("Sent prompts/list response for ID: %v", req.ID)
case methodPromptsGet: case methodPromptsGet:
logx.Infof("Processing prompts/get request with ID: %d", req.ID) logx.Infof("Processing prompts/get request with ID: %v", req.ID)
s.processGetPrompt(r.Context(), client, req) s.processGetPrompt(r.Context(), client, req)
logx.Infof("Sent prompts/get response for ID: %d", req.ID) logx.Infof("Sent prompts/get response for ID: %v", req.ID)
case methodResourcesList: case methodResourcesList:
logx.Infof("Processing resources/list request with ID: %d", req.ID) logx.Infof("Processing resources/list request with ID: %v", req.ID)
s.processListResources(r.Context(), client, req) s.processListResources(r.Context(), client, req)
logx.Infof("Sent resources/list response for ID: %d", req.ID) logx.Infof("Sent resources/list response for ID: %v", req.ID)
case methodResourcesRead: case methodResourcesRead:
logx.Infof("Processing resources/read request with ID: %d", req.ID) logx.Infof("Processing resources/read request with ID: %v", req.ID)
s.processResourcesRead(r.Context(), client, req) s.processResourcesRead(r.Context(), client, req)
logx.Infof("Sent resources/read response for ID: %d", req.ID) logx.Infof("Sent resources/read response for ID: %v", req.ID)
case methodResourcesSubscribe: case methodResourcesSubscribe:
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID) logx.Infof("Processing resources/subscribe request with ID: %v", req.ID)
s.processResourceSubscribe(r.Context(), client, req) s.processResourceSubscribe(r.Context(), client, req)
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID) logx.Infof("Sent resources/subscribe response for ID: %v", req.ID)
case methodPing: case methodPing:
logx.Infof("Processing ping request with ID: %d", req.ID) logx.Infof("Processing ping request with ID: %v", req.ID)
s.processPing(r.Context(), client, req) s.processPing(r.Context(), client, req)
case methodNotificationsCancelled: case methodNotificationsCancelled:
logx.Infof("Received notifications/cancelled notification: %d", req.ID) logx.Infof("Received notifications/cancelled notification: %v", req.ID)
s.processNotificationCancelled(r.Context(), client, req) s.processNotificationCancelled(r.Context(), client, req)
default: default:
logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID) logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID)
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound) s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
} }
} }
@@ -809,7 +812,7 @@ func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClie
} }
// Ensure MimeType is set if available from the resource definition // Ensure MimeType is set if available from the resource definition
if len(content.MimeType) == 0 && resource.MimeType != "" { if len(content.MimeType) == 0 && len(resource.MimeType) > 0 {
content.MimeType = resource.MimeType content.MimeType = resource.MimeType
} }
@@ -880,10 +883,10 @@ func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req R
// sendErrorResponse sends an error response via the SSE channel // sendErrorResponse sends an error response via the SSE channel
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient, func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
id int64, message string, code int) { id any, message string, code int) {
errorResponse := struct { errorResponse := struct {
JsonRpc string `json:"jsonrpc"` JsonRpc string `json:"jsonrpc"`
ID int64 `json:"id"` ID any `json:"id"`
Error errorMessage `json:"error"` Error errorMessage `json:"error"`
}{ }{
JsonRpc: jsonRpcVersion, JsonRpc: jsonRpcVersion,
@@ -898,7 +901,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
jsonData, _ := json.Marshal(errorResponse) jsonData, _ := json.Marshal(errorResponse)
// Use CRLF line endings as requested // Use CRLF line endings as requested
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData)) sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending error for ID %d: %s", id, sseMessage) logx.Infof("Sending error for ID %v: %s", id, sseMessage)
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages // cannot receive from ctx.Done() because we're sending to the channel for SSE messages
select { select {
@@ -910,7 +913,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
} }
// sendResponse sends a success response via the SSE channel // sendResponse sends a success response via the SSE channel
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) { func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) {
response := Response{ response := Response{
JsonRpc: jsonRpcVersion, JsonRpc: jsonRpcVersion,
ID: id, ID: id,
@@ -925,13 +928,13 @@ func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id i
// Use CRLF line endings as requested // Use CRLF line endings as requested
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData)) sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending response for ID %d: %s", id, sseMessage) logx.Infof("Sending response for ID %v: %s", id, sseMessage)
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages // cannot receive from ctx.Done() because we're sending to the channel for SSE messages
select { select {
case client.channel <- sseMessage: case client.channel <- sseMessage:
default: default:
// Channel buffer is full, log warning and continue // Channel buffer is full, log warning and continue
logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id) logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id)
} }
} }

View File

@@ -175,6 +175,20 @@ func TestHandleRequest_badRequest(t *testing.T) {
mock.server.handleRequest(w, r) mock.server.handleRequest(w, r)
assert.Equal(t, http.StatusBadRequest, w.Code) assert.Equal(t, http.StatusBadRequest, w.Code)
}) })
t.Run("bad id", func(t *testing.T) {
mock := newMockMcpServer(t)
defer mock.shutdown()
addTestClient(mock.server, "test-session", true)
body := `{"jsonrpc": "2.0", "id": {}, "method": "tools.call", "params": {}}`
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-session", bytes.NewReader([]byte(body)))
w := httptest.NewRecorder()
mock.server.handleRequest(w, r)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "Invalid request.ID")
})
} }
func TestRegisterTool(t *testing.T) { func TestRegisterTool(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package mcp
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"sync" "sync"
"github.com/zeromicro/go-zero/rest" "github.com/zeromicro/go-zero/rest"
@@ -15,11 +16,28 @@ type Cursor string
type Request struct { type Request struct {
SessionId string `form:"session_id"` // Session identifier for client tracking SessionId string `form:"session_id"` // Session identifier for client tracking
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
ID int64 `json:"id"` // Request identifier for matching responses ID any `json:"id"` // Request identifier for matching responses
Method string `json:"method"` // Method name to invoke Method string `json:"method"` // Method name to invoke
Params json.RawMessage `json:"params"` // Parameters for the method Params json.RawMessage `json:"params"` // Parameters for the method
} }
func (r Request) isNotification() (bool, error) {
switch val := r.ID.(type) {
case int:
return val == 0, nil
case int64:
return val == 0, nil
case float64:
return val == 0.0, nil
case string:
return len(val) == 0, nil
case nil:
return true, nil
default:
return false, fmt.Errorf("invalid type %T", val)
}
}
type PaginatedParams struct { type PaginatedParams struct {
Cursor string `json:"cursor"` Cursor string `json:"cursor"`
Meta struct { Meta struct {
@@ -116,13 +134,8 @@ type FileContent struct {
// EmbeddedResource represents a resource embedded in a message // EmbeddedResource represents a resource embedded in a message
type EmbeddedResource struct { type EmbeddedResource struct {
Type string `json:"type"` // Always "resource" Type string `json:"type"` // Always "resource"
Resource struct { Resource ResourceContent `json:"resource"` // The resource data
URI string `json:"uri"` // Resource URI
MimeType string `json:"mimeType"` // MIME type of the resource
Text string `json:"text,omitempty"` // Text content (if available)
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
} `json:"resource"` // The resource data
} }
// Annotations provides additional metadata for content // Annotations provides additional metadata for content
@@ -249,7 +262,7 @@ type errorObj struct {
// Response represents a JSON-RPC response // Response represents a JSON-RPC response
type Response struct { type Response struct {
JsonRpc string `json:"jsonrpc"` // Always "2.0" JsonRpc string `json:"jsonrpc"` // Always "2.0"
ID int64 `json:"id"` // Same as request ID ID any `json:"id"` // Same as request ID
Result any `json:"result"` // Result object (null if error) Result any `json:"result"` // Result object (null if error)
Error *errorObj `json:"error,omitempty"` // Error object (null if success) Error *errorObj `json:"error,omitempty"` // Error object (null if success)
} }

View File

@@ -3,6 +3,7 @@ package mcp
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -55,7 +56,7 @@ func TestRequestUnmarshaling(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "2.0", req.JsonRpc) assert.Equal(t, "2.0", req.JsonRpc)
assert.Equal(t, int64(789), req.ID) assert.Equal(t, float64(789), req.ID)
assert.Equal(t, "test_method", req.Method) assert.Equal(t, "test_method", req.Method)
// Check params unmarshaled correctly // Check params unmarshaled correctly
@@ -204,3 +205,67 @@ func TestCallToolResult(t *testing.T) {
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`) assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
assert.NotContains(t, string(data), `"isError":`) assert.NotContains(t, string(data), `"isError":`)
} }
func TestRequest_isNotification(t *testing.T) {
tests := []struct {
name string
id any
want bool
wantErr error
}{
// integer test cases
{name: "int zero", id: 0, want: true, wantErr: nil},
{name: "int non-zero", id: 1, want: false, wantErr: nil},
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
// floating point number test cases
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
// string test cases
{name: "empty string", id: "", want: true, wantErr: nil},
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
{name: "space string", id: " ", want: false, wantErr: nil},
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
// special cases
{name: "nil", id: nil, want: true, wantErr: nil},
// logical type test cases
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := Request{
SessionId: "test-session",
JsonRpc: "2.0",
ID: tt.id,
Method: "testMethod",
Params: json.RawMessage(`{}`),
}
got, err := req.isNotification()
if (err != nil) != (tt.wantErr != nil) {
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
}
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
}
if got != tt.want {
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
}
})
}
}

View File

@@ -6,7 +6,6 @@
[English](readme.md) | 简体中文 [English](readme.md) | 简体中文
[![Go](https://github.com/zeromicro/go-zero/workflows/Go/badge.svg?branch=master)](https://github.com/zeromicro/go-zero/actions)
[![Go Report Card](https://goreportcard.com/badge/github.com/zeromicro/go-zero)](https://goreportcard.com/report/github.com/zeromicro/go-zero) [![Go Report Card](https://goreportcard.com/badge/github.com/zeromicro/go-zero)](https://goreportcard.com/report/github.com/zeromicro/go-zero)
[![goproxy](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg)](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg) [![goproxy](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg)](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg)
[![codecov](https://codecov.io/gh/zeromicro/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/zeromicro/go-zero) [![codecov](https://codecov.io/gh/zeromicro/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/zeromicro/go-zero)
@@ -304,6 +303,7 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
>105. 昆仑万维科技股份有限公司 >105. 昆仑万维科技股份有限公司
>106. 无锡盛算信息技术有限公司 >106. 无锡盛算信息技术有限公司
>107. 深圳市聚货通信息科技有限公司 >107. 深圳市聚货通信息科技有限公司
>108. 浙江银盾云科技有限公司
如果贵公司也已使用 go-zero欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。 如果贵公司也已使用 go-zero欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。

View File

@@ -7,7 +7,6 @@ go-zero is a web and rpc framework with lots of builtin engineering practices. I
<div align=center> <div align=center>
[![Go](https://github.com/zeromicro/go-zero/workflows/Go/badge.svg?branch=master)](https://github.com/zeromicro/go-zero/actions)
[![codecov](https://codecov.io/gh/zeromicro/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/zeromicro/go-zero) [![codecov](https://codecov.io/gh/zeromicro/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/zeromicro/go-zero)
[![Go Report Card](https://goreportcard.com/badge/github.com/zeromicro/go-zero)](https://goreportcard.com/report/github.com/zeromicro/go-zero) [![Go Report Card](https://goreportcard.com/badge/github.com/zeromicro/go-zero)](https://goreportcard.com/report/github.com/zeromicro/go-zero)
[![Release](https://img.shields.io/github/v/release/zeromicro/go-zero.svg?style=flat-square)](https://github.com/zeromicro/go-zero) [![Release](https://img.shields.io/github/v/release/zeromicro/go-zero.svg?style=flat-square)](https://github.com/zeromicro/go-zero)
@@ -251,7 +250,3 @@ go-zero enlisted in the [CNCF Cloud Native Landscape](https://landscape.cncf.io/
## Give a Star! ⭐ ## Give a Star! ⭐
If you like this project or are using it to learn or start your own solution, give it a star to get updates on new releases. Your support matters! If you like this project or are using it to learn or start your own solution, give it a star to get updates on new releases. Your support matters!
## Buy me a coffee
<a href="https://www.buymeacoffee.com/kevwan" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>

View File

@@ -28,7 +28,10 @@ var ErrSignatureConfig = errors.New("bad config for Signature")
type engine struct { type engine struct {
conf RestConf conf RestConf
routes []featuredRoutes routes []featuredRoutes
// timeout is the max timeout of all routes // timeout is the max timeout of all routes,
// and is used to set http.Server.ReadTimeout and http.Server.WriteTimeout.
// this network timeout is used to avoid DoS attacks by sending data slowly
// or receiving data slowly with many connections to exhaust server resources.
timeout time.Duration timeout time.Duration
unauthorizedCallback handler.UnauthorizedCallback unauthorizedCallback handler.UnauthorizedCallback
unsignedCallback handler.UnsignedCallback unsignedCallback handler.UnsignedCallback
@@ -60,11 +63,7 @@ func (ng *engine) addRoutes(r featuredRoutes) {
} }
ng.routes = append(ng.routes, r) ng.routes = append(ng.routes, r)
// need to guarantee the timeout is the max of all routes ng.mightUpdateTimeout(r)
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
if r.timeout > ng.timeout {
ng.timeout = r.timeout
}
} }
func buildSSERoutes(routes []Route) []Route { func buildSSERoutes(routes []Route) []Route {
@@ -192,11 +191,12 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 {
return ng.conf.MaxBytes return ng.conf.MaxBytes
} }
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration { func (ng *engine) checkedTimeout(timeout *time.Duration) time.Duration {
if timeout > 0 { if timeout != nil {
return timeout return *timeout
} }
// if timeout not set in featured routes, use global timeout
return time.Duration(ng.conf.Timeout) * time.Millisecond return time.Duration(ng.conf.Timeout) * time.Millisecond
} }
@@ -228,6 +228,32 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
return ng.shedder return ng.shedder
} }
func (ng *engine) hasTimeout() bool {
return ng.conf.Middlewares.Timeout && ng.timeout > 0
}
// mightUpdateTimeout checks if the route timeout is greater than the current,
// and updates the engine's timeout accordingly.
func (ng *engine) mightUpdateTimeout(r featuredRoutes) {
// if global timeout is set to 0, it means no need to set read/write timeout
// if route timeout is nil, no need to update ng.timeout
if ng.timeout == 0 || r.timeout == nil {
return
}
// if route timeout is 0 (means no timeout), cannot set read/write timeout
if *r.timeout == 0 {
ng.timeout = 0
return
}
// need to guarantee the timeout is the max of all routes
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
if *r.timeout > ng.timeout {
ng.timeout = *r.timeout
}
}
// notFoundHandler returns a middleware that handles 404 not found requests. // notFoundHandler returns a middleware that handles 404 not found requests.
func (ng *engine) notFoundHandler(next http.Handler) http.Handler { func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -329,7 +355,7 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
} }
// make sure user defined options overwrite default options // make sure user defined options overwrite default options
opts = append([]StartOption{ng.withTimeout()}, opts...) opts = append([]StartOption{ng.withNetworkTimeout()}, opts...)
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 { if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...) return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
@@ -352,18 +378,19 @@ func (ng *engine) use(middleware Middleware) {
ng.middlewares = append(ng.middlewares, middleware) ng.middlewares = append(ng.middlewares, middleware)
} }
func (ng *engine) withTimeout() internal.StartOption { func (ng *engine) withNetworkTimeout() internal.StartOption {
return func(svr *http.Server) { return func(svr *http.Server) {
timeout := ng.timeout if !ng.hasTimeout() {
if timeout > 0 { return
// factor 0.8, to avoid clients send longer content-length than the actual content,
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
// which triggers the circuit breaker.
svr.ReadTimeout = 4 * timeout / 5
// factor 1.1, to avoid servers don't have enough time to write responses.
// setting the factor less than 1.0 may lead clients not receiving the responses.
svr.WriteTimeout = 11 * timeout / 10
} }
// factor 0.8, to avoid clients send longer content-length than the actual content,
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
// which triggers the circuit breaker.
svr.ReadTimeout = 4 * ng.timeout / 5
// factor 1.1, to avoid servers don't have enough time to write responses.
// setting the factor less than 1.0 may lead clients not receiving the responses.
svr.WriteTimeout = 11 * ng.timeout / 10
} }
} }

View File

@@ -73,7 +73,17 @@ Verbose: true
Path: "/", Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {}, Handler: func(w http.ResponseWriter, r *http.Request) {},
}}, }},
timeout: time.Minute, timeout: ptrOfDuration(time.Minute),
},
{
jwt: jwtSetting{},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
timeout: ptrOfDuration(0),
}, },
{ {
priority: true, priority: true,
@@ -84,7 +94,7 @@ Verbose: true
Path: "/", Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {}, Handler: func(w http.ResponseWriter, r *http.Request) {},
}}, }},
timeout: time.Second, timeout: ptrOfDuration(time.Second),
}, },
{ {
priority: true, priority: true,
@@ -227,8 +237,12 @@ Verbose: true
})) }))
timeout := time.Second * 3 timeout := time.Second * 3
if route.timeout > timeout { if route.timeout != nil {
timeout = route.timeout if *route.timeout == 0 {
timeout = 0
} else if *route.timeout > timeout {
timeout = *route.timeout
}
} }
assert.Equal(t, timeout, ng.timeout) assert.Equal(t, timeout, ng.timeout)
}) })
@@ -236,10 +250,69 @@ Verbose: true
} }
} }
func TestNewEngine_unsignedCallback(t *testing.T) {
priKeyfile, err := fs.TempFilenameWithText(priKey)
assert.Nil(t, err)
defer os.Remove(priKeyfile)
yaml := `Name: foo
Host: localhost
Port: 0
Middlewares:
Log: false
`
route := featuredRoutes{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{
enabled: true,
SignatureConf: SignatureConf{
Strict: true,
PrivateKeys: []PrivateKeyConf{
{
Fingerprint: "a",
KeyFile: priKeyfile,
},
},
},
},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
}
var index int32
t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
var cnf RestConf
assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
ng := newEngine(cnf)
if atomic.AddInt32(&index, 1)%2 == 0 {
ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request,
next http.Handler, strict bool, code int) {
})
}
ng.addRoutes(route)
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
}
})
assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
}))
assert.Equal(t, time.Duration(time.Second*3), ng.timeout)
})
}
func TestEngine_checkedTimeout(t *testing.T) { func TestEngine_checkedTimeout(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
timeout time.Duration timeout *time.Duration
expect time.Duration expect time.Duration
}{ }{
{ {
@@ -248,17 +321,17 @@ func TestEngine_checkedTimeout(t *testing.T) {
}, },
{ {
name: "less", name: "less",
timeout: time.Millisecond * 500, timeout: ptrOfDuration(time.Millisecond * 500),
expect: time.Millisecond * 500, expect: time.Millisecond * 500,
}, },
{ {
name: "equal", name: "equal",
timeout: time.Second, timeout: ptrOfDuration(time.Second),
expect: time.Second, expect: time.Second,
}, },
{ {
name: "more", name: "more",
timeout: time.Millisecond * 1500, timeout: ptrOfDuration(time.Millisecond * 1500),
expect: time.Millisecond * 1500, expect: time.Millisecond * 1500,
}, },
} }
@@ -394,9 +467,14 @@ func TestEngine_withTimeout(t *testing.T) {
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
ng := newEngine(RestConf{Timeout: test.timeout}) ng := newEngine(RestConf{
Timeout: test.timeout,
Middlewares: MiddlewaresConf{
Timeout: true,
},
})
svr := &http.Server{} svr := &http.Server{}
ng.withTimeout()(svr) ng.withNetworkTimeout()(svr)
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout) assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout) assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
@@ -406,6 +484,62 @@ func TestEngine_withTimeout(t *testing.T) {
} }
} }
func TestEngine_ReadWriteTimeout(t *testing.T) {
logx.Disable()
tests := []struct {
name string
timeout int64
middleware bool
}{
{
name: "0/false",
timeout: 0,
middleware: false,
},
{
name: "0/true",
timeout: 0,
middleware: true,
},
{
name: "set/false",
timeout: 1000,
middleware: false,
},
{
name: "both set",
timeout: 1000,
middleware: true,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
ng := newEngine(RestConf{
Timeout: test.timeout,
Middlewares: MiddlewaresConf{
Timeout: test.middleware,
},
})
svr := &http.Server{}
ng.withNetworkTimeout()(svr)
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
assert.Equal(t, time.Duration(0), svr.IdleTimeout)
if test.timeout > 0 && test.middleware {
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
} else {
assert.Equal(t, time.Duration(0), svr.ReadTimeout)
assert.Equal(t, time.Duration(0), svr.WriteTimeout)
}
})
}
}
func TestEngine_start(t *testing.T) { func TestEngine_start(t *testing.T) {
logx.Disable() logx.Disable()

View File

@@ -106,8 +106,8 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case <-ctx.Done(): case <-ctx.Done():
tw.mu.Lock() tw.mu.Lock()
defer tw.mu.Unlock() defer tw.mu.Unlock()
// there isn't any user-defined middleware before TimoutHandler, // there isn't any user-defined middleware before TimeoutHandler,
// so we can guarantee that cancelation in biz related code won't come here. // so we can guarantee that cancellation in biz related code won't come here.
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) { httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
w.WriteHeader(statusClientClosedRequest) w.WriteHeader(statusClientClosedRequest)
@@ -151,7 +151,7 @@ func (tw *timeoutWriter) Flush() {
flusher.Flush() flusher.Flush()
} }
// Header returns the underline temporary http.Header. // Header returns the underlying temporary http.Header.
func (tw *timeoutWriter) Header() http.Header { func (tw *timeoutWriter) Header() http.Header {
return tw.h return tw.h
} }

View File

@@ -2,15 +2,16 @@ package fileserver
import ( import (
"net/http" "net/http"
"path"
"strings" "strings"
"sync" "sync"
) )
// Middleware returns a middleware that serves files from the given file system. // Middleware returns a middleware that serves files from the given file system.
func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc { func Middleware(upath string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
fileServer := http.FileServer(fs) fileServer := http.FileServer(fs)
pathWithoutTrailSlash := ensureNoTrailingSlash(path) pathWithoutTrailSlash := ensureNoTrailingSlash(upath)
canServe := createServeChecker(path, fs) canServe := createServeChecker(upath, fs)
return func(next http.HandlerFunc) http.HandlerFunc { return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
@@ -28,9 +29,22 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
var lock sync.RWMutex var lock sync.RWMutex
fileChecker := make(map[string]bool) fileChecker := make(map[string]bool)
return func(path string) bool { return func(upath string) bool {
// Emulate http.Dir.Opens path normalization for embed.FS.Open.
// http.FileServer redirects any request ending in "/index.html"
// to the same path without the final "index.html".
// So the path here may be empty or end with a "/".
// http.Dir.Open uses this logic to clean the path,
// correctly handling those two cases.
// embed.FS doesnt perform this normalization, so we apply the same logic here.
upath = path.Clean("/" + upath)[1:]
if len(upath) == 0 {
// if the path is empty, we use "." to open the current directory
upath = "."
}
lock.RLock() lock.RLock()
exist, ok := fileChecker[path] exist, ok := fileChecker[upath]
lock.RUnlock() lock.RUnlock()
if ok { if ok {
return exist return exist
@@ -39,9 +53,9 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
file, err := fs.Open(path) file, err := fs.Open(upath)
exist = err == nil exist = err == nil
fileChecker[path] = exist fileChecker[upath] = exist
if err != nil { if err != nil {
return false return false
} }
@@ -51,8 +65,8 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
} }
} }
func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) bool { func createServeChecker(upath string, fs http.FileSystem) func(r *http.Request) bool {
pathWithTrailSlash := ensureTrailingSlash(path) pathWithTrailSlash := ensureTrailingSlash(upath)
fileChecker := createFileChecker(fs) fileChecker := createFileChecker(fs)
return func(r *http.Request) bool { return func(r *http.Request) bool {
@@ -62,18 +76,18 @@ func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) b
} }
} }
func ensureTrailingSlash(path string) string { func ensureTrailingSlash(upath string) string {
if strings.HasSuffix(path, "/") { if strings.HasSuffix(upath, "/") {
return path return upath
} }
return path + "/" return upath + "/"
} }
func ensureNoTrailingSlash(path string) string { func ensureNoTrailingSlash(upath string) string {
if strings.HasSuffix(path, "/") { if strings.HasSuffix(upath, "/") {
return path[:len(path)-1] return upath[:len(upath)-1]
} }
return path return upath
} }

View File

@@ -1,6 +1,8 @@
package fileserver package fileserver
import ( import (
"embed"
"io/fs"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@@ -61,6 +63,46 @@ func TestMiddleware(t *testing.T) {
requestPath: "/ws", requestPath: "/ws",
expectedStatus: http.StatusAlreadyReported, expectedStatus: http.StatusAlreadyReported,
}, },
// http.FileServer redirects any request ending in "/index.html"
// to the same path, without the final "index.html".
{
name: "Serve index.html",
path: "/static",
dir: "testdata",
requestPath: "/static/index.html",
expectedStatus: http.StatusMovedPermanently,
},
{
name: "Serve index.html with path with trailing slash",
path: "/static/",
dir: "testdata",
requestPath: "/static/index.html",
expectedStatus: http.StatusMovedPermanently,
},
{
name: "Serve index.html in a nested directory",
path: "/static",
dir: "testdata",
requestPath: "/static/nested/index.html",
expectedStatus: http.StatusMovedPermanently,
},
{
name: "Request index.html indirectly",
path: "/static",
dir: "testdata",
requestPath: "/static/",
expectedStatus: http.StatusOK,
expectedContent: "hello",
},
{
name: "Request index.html in a nested directory indirectly",
path: "/static",
dir: "testdata",
requestPath: "/static/nested/",
expectedStatus: http.StatusOK,
expectedContent: "hello",
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -87,6 +129,128 @@ func TestMiddleware(t *testing.T) {
} }
} }
var (
//go:embed testdata
testdataFS embed.FS
)
func TestMiddleware_embedFS(t *testing.T) {
tests := []struct {
name string
path string
requestPath string
expectedStatus int
expectedContent string
}{
{
name: "Serve static file",
path: "/static",
requestPath: "/static/example.txt",
expectedStatus: http.StatusOK,
expectedContent: "1",
},
{
name: "Path with trailing slash",
path: "/static/",
requestPath: "/static/example.txt",
expectedStatus: http.StatusOK,
expectedContent: "1",
},
{
name: "Root path",
path: "/",
requestPath: "/example.txt",
expectedStatus: http.StatusOK,
expectedContent: "1",
},
{
name: "Pass through non-matching path",
path: "/static/",
requestPath: "/other/path",
expectedStatus: http.StatusAlreadyReported,
},
{
name: "Not exist file",
path: "/assets",
requestPath: "/assets/not-exist.txt",
expectedStatus: http.StatusAlreadyReported,
},
{
name: "Not exist file in root",
path: "/",
requestPath: "/not-exist.txt",
expectedStatus: http.StatusAlreadyReported,
},
{
name: "websocket request",
path: "/",
requestPath: "/ws",
expectedStatus: http.StatusAlreadyReported,
},
// http.FileServer redirects any request ending in "/index.html"
// to the same path, without the final "index.html".
{
name: "Serve index.html",
path: "/static",
requestPath: "/static/index.html",
expectedStatus: http.StatusMovedPermanently,
},
{
name: "Serve index.html with path with trailing slash",
path: "/static/",
requestPath: "/static/index.html",
expectedStatus: http.StatusMovedPermanently,
},
{
name: "Serve index.html in a nested directory",
path: "/static",
requestPath: "/static/nested/index.html",
expectedStatus: http.StatusMovedPermanently,
},
{
name: "Request index.html indirectly",
path: "/static",
requestPath: "/static/",
expectedStatus: http.StatusOK,
expectedContent: "hello",
},
{
name: "Request index.html in a nested directory indirectly",
path: "/static",
requestPath: "/static/nested/",
expectedStatus: http.StatusOK,
expectedContent: "hello",
},
}
subFS, err := fs.Sub(testdataFS, "testdata")
assert.Nil(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
middleware := Middleware(tt.path, http.FS(subFS))
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusAlreadyReported)
})
handlerToTest := middleware(nextHandler)
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
rr := httptest.NewRecorder()
handlerToTest.ServeHTTP(rr, req)
assert.Equal(t, tt.expectedStatus, rr.Code)
if len(tt.expectedContent) > 0 {
assert.Equal(t, tt.expectedContent, rr.Body.String())
}
}
})
}
}
func TestEnsureTrailingSlash(t *testing.T) { func TestEnsureTrailingSlash(t *testing.T) {
tests := []struct { tests := []struct {
input string input string

View File

@@ -0,0 +1 @@
hello

View File

@@ -0,0 +1 @@
hello

View File

@@ -119,6 +119,16 @@ func (s *Server) Use(middleware Middleware) {
s.ngin.use(middleware) s.ngin.use(middleware)
} }
// build builds the Server and binds the routes to the router.
func (s *Server) build() error {
return s.ngin.bindRoutes(s.router)
}
// serve serves the HTTP requests using the Server's router.
func (s *Server) serve(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
}
// ToMiddleware converts the given handler to a Middleware. // ToMiddleware converts the given handler to a Middleware.
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
return func(handle http.HandlerFunc) http.HandlerFunc { return func(handle http.HandlerFunc) http.HandlerFunc {
@@ -283,14 +293,14 @@ func WithSignature(signature SignatureConf) RouteOption {
func WithSSE() RouteOption { func WithSSE() RouteOption {
return func(r *featuredRoutes) { return func(r *featuredRoutes) {
r.sse = true r.sse = true
r.timeout = 0 r.timeout = ptrOfDuration(0)
} }
} }
// WithTimeout returns a RouteOption to set timeout with given value. // WithTimeout returns a RouteOption to set timeout with given value.
func WithTimeout(timeout time.Duration) RouteOption { func WithTimeout(timeout time.Duration) RouteOption {
return func(r *featuredRoutes) { return func(r *featuredRoutes) {
r.timeout = timeout r.timeout = &timeout
} }
} }
@@ -325,6 +335,10 @@ func handleError(err error) {
panic(err) panic(err)
} }
func ptrOfDuration(d time.Duration) *time.Duration {
return &d
}
func validateSecret(secret string) { func validateSecret(secret string) {
if len(secret) < 8 { if len(secret) < 8 {
panic("secret's length can't be less than 8") panic("secret's length can't be less than 8")

View File

@@ -345,7 +345,7 @@ func TestWithPriority(t *testing.T) {
func TestWithTimeout(t *testing.T) { func TestWithTimeout(t *testing.T) {
var fr featuredRoutes var fr featuredRoutes
WithTimeout(time.Hour)(&fr) WithTimeout(time.Hour)(&fr)
assert.Equal(t, time.Hour, fr.timeout) assert.Equal(t, time.Hour, *fr.timeout)
} }
func TestWithTLSConfig(t *testing.T) { func TestWithTLSConfig(t *testing.T) {
@@ -819,6 +819,6 @@ func TestServerEmbedFileSystem(t *testing.T) {
// serve(server, w, r) // serve(server, w, r)
// // verify the response // // verify the response
func serve(s *Server, w http.ResponseWriter, r *http.Request) { func serve(s *Server, w http.ResponseWriter, r *http.Request) {
s.ngin.bindRoutes(s.router) _ = s.build()
s.router.ServeHTTP(w, r) s.serve(w, r)
} }

27
rest/serverless.go Normal file
View File

@@ -0,0 +1,27 @@
package rest
import "net/http"
// Serverless is a wrapper around Server that allows it to be used in serverless environments.
type Serverless struct {
server *Server
}
// NewServerless creates a new Serverless instance from the provided Server.
func NewServerless(server *Server) (*Serverless, error) {
// Ensure the server is built before using it in a serverless context.
// Why not call server.build() when serving requests,
// is because we need to ensure fail fast behavior.
if err := server.build(); err != nil {
return nil, err
}
return &Serverless{
server: server,
}, nil
}
// Serve handles HTTP requests by delegating them to the underlying Server instance.
func (s *Serverless) Serve(w http.ResponseWriter, r *http.Request) {
s.server.serve(w, r)
}

67
rest/serverless_test.go Normal file
View File

@@ -0,0 +1,67 @@
package rest
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/logx/logtest"
)
func TestNewServerless(t *testing.T) {
logtest.Discard(t)
const configYaml = `
Name: foo
Host: localhost
Port: 0
`
var cnf RestConf
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
svr, err := NewServer(cnf)
assert.NoError(t, err)
svr.AddRoute(Route{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello World"))
},
})
serverless, err := NewServerless(svr)
assert.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
serverless.Serve(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "Hello World", w.Body.String())
}
func TestNewServerlessWithError(t *testing.T) {
logtest.Discard(t)
const configYaml = `
Name: foo
Host: localhost
Port: 0
`
var cnf RestConf
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
svr, err := NewServer(cnf)
assert.NoError(t, err)
svr.AddRoute(Route{
Method: http.MethodGet,
Path: "notstartwith/",
Handler: nil,
})
_, err = NewServerless(svr)
assert.Error(t, err)
}

View File

@@ -31,7 +31,7 @@ type (
} }
featuredRoutes struct { featuredRoutes struct {
timeout time.Duration timeout *time.Duration
priority bool priority bool
jwt jwtSetting jwt jwtSetting
signature signatureSetting signature signatureSetting

1
tools/goctl/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
dist

View File

@@ -77,6 +77,7 @@ func init() {
goCmdFlags.StringVar(&gogen.VarStringRemote, "remote") goCmdFlags.StringVar(&gogen.VarStringRemote, "remote")
goCmdFlags.StringVar(&gogen.VarStringBranch, "branch") goCmdFlags.StringVar(&gogen.VarStringBranch, "branch")
goCmdFlags.BoolVar(&gogen.VarBoolWithTest, "test") goCmdFlags.BoolVar(&gogen.VarBoolWithTest, "test")
goCmdFlags.BoolVar(&gogen.VarBoolTypeGroup, "type-group")
goCmdFlags.StringVarWithDefaultValue(&gogen.VarStringStyle, "style", config.DefaultFormat) goCmdFlags.StringVarWithDefaultValue(&gogen.VarStringStyle, "style", config.DefaultFormat)
javaCmdFlags.StringVar(&javagen.VarStringDir, "dir") javaCmdFlags.StringVar(&javagen.VarStringDir, "dir")

View File

@@ -40,6 +40,8 @@ var (
// VarStringStyle describes the style of output files. // VarStringStyle describes the style of output files.
VarStringStyle string VarStringStyle string
VarBoolWithTest bool VarBoolWithTest bool
// VarBoolTypeGroup describes whether to group types.
VarBoolTypeGroup bool
) )
// GoCommand gen go project files from command line // GoCommand gen go project files from command line

View File

@@ -9,11 +9,11 @@ import (
"sort" "sort"
"strings" "strings"
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/tools/goctl/api/spec" "github.com/zeromicro/go-zero/tools/goctl/api/spec"
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util" apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
"github.com/zeromicro/go-zero/tools/goctl/config" "github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/internal/version" "github.com/zeromicro/go-zero/tools/goctl/internal/version"
"github.com/zeromicro/go-zero/tools/goctl/pkg/env"
"github.com/zeromicro/go-zero/tools/goctl/util" "github.com/zeromicro/go-zero/tools/goctl/util"
"github.com/zeromicro/go-zero/tools/goctl/util/format" "github.com/zeromicro/go-zero/tools/goctl/util/format"
) )
@@ -41,53 +41,116 @@ func BuildTypes(types []spec.Type) (string, error) {
return builder.String(), nil return builder.String(), nil
} }
func removeTypeFromDefault(tp spec.Type, group string, groupTypes map[string]map[string]spec.Type) map[string]map[string]spec.Type { func getTypeName(tp spec.Type) string {
if tp == nil {
return ""
}
switch val := tp.(type) { switch val := tp.(type) {
case spec.DefineStruct: case spec.DefineStruct:
typeName := util.Title(tp.Name()) typeName := util.Title(tp.Name())
defaultGroups, ok := groupTypes[groupTypeDefault] return typeName
if ok {
delete(defaultGroups, typeName)
types, ok := groupTypes[group]
if !ok {
types = make(map[string]spec.Type)
}
types[typeName] = tp
groupTypes[group] = types
}
groupTypes[groupTypeDefault] = defaultGroups
case spec.PointerType: case spec.PointerType:
groupTypes = removeTypeFromDefault(val.Type, group, groupTypes) return getTypeName(val.Type)
case spec.ArrayType: case spec.ArrayType:
groupTypes = removeTypeFromDefault(val.Value, group, groupTypes) return getTypeName(val.Value)
} }
return groupTypes return ""
} }
func genTypesWithGroup(dir string, cfg *config.Config, api *spec.ApiSpec) error { func genTypesWithGroup(dir string, cfg *config.Config, api *spec.ApiSpec) error {
groupTypes := make(map[string]map[string]spec.Type) groupTypes := make(map[string]map[string]spec.Type)
for _, v := range api.Types { typesBelongToFiles := make(map[string]*collection.Set)
types, ok := groupTypes[groupTypeDefault]
if !ok {
types = make(map[string]spec.Type)
}
types[util.Title(v.Name())] = v
groupTypes[groupTypeDefault] = types
}
for _, v := range api.Service.Groups { for _, v := range api.Service.Groups {
group := v.GetAnnotation(groupProperty) group := v.GetAnnotation(groupProperty)
if len(group) == 0 { if len(group) == 0 {
group = groupTypeDefault
}
// convert filepath to Identifier name spec.
group = strings.TrimPrefix(group, "/")
group = strings.TrimSuffix(group, "/")
group = util.SafeString(group)
for _, v := range v.Routes {
requestTypeName := getTypeName(v.RequestType)
responseTypeName := getTypeName(v.ResponseType)
requestTypeFileSet, ok := typesBelongToFiles[requestTypeName]
if !ok {
requestTypeFileSet = collection.NewSet()
}
if len(requestTypeName) > 0 {
requestTypeFileSet.AddStr(group)
typesBelongToFiles[requestTypeName] = requestTypeFileSet
}
responseTypeFileSet, ok := typesBelongToFiles[responseTypeName]
if !ok {
responseTypeFileSet = collection.NewSet()
}
if len(responseTypeName) > 0 {
responseTypeFileSet.AddStr(group)
typesBelongToFiles[responseTypeName] = responseTypeFileSet
}
}
}
typesInOneFile := make(map[string]*collection.Set)
for typeName, fileSet := range typesBelongToFiles {
count := fileSet.Count()
switch {
case count == 0: // it means there has no structure type or no request/response body
continue
case count == 1: // it means a structure type used in only one group.
groupName := fileSet.KeysStr()[0]
typeSet, ok := typesInOneFile[groupName]
if !ok {
typeSet = collection.NewSet()
}
typeSet.AddStr(typeName)
typesInOneFile[groupName] = typeSet
default: // it means this type is used in multiple groups.
continue continue
} }
for _, v := range v.Routes { }
if v.RequestType != nil {
groupTypes = removeTypeFromDefault(v.RequestType, group, groupTypes) for _, v := range api.Types {
} typeName := util.Title(v.Name())
if v.ResponseType != nil { groupSet, ok := typesBelongToFiles[typeName]
groupTypes = removeTypeFromDefault(v.ResponseType, group, groupTypes) var typeCount int
} if !ok {
typeCount = 0
} else {
typeCount = groupSet.Count()
} }
if typeCount == 0 { // not belong to any group
types, ok := groupTypes[groupTypeDefault]
if !ok {
types = make(map[string]spec.Type)
}
types[typeName] = v
groupTypes[groupTypeDefault] = types
continue
}
if typeCount == 1 { // belong to one group
groupName := groupSet.KeysStr()[0]
types, ok := groupTypes[groupName]
if !ok {
types = make(map[string]spec.Type)
}
types[typeName] = v
groupTypes[groupName] = types
continue
}
// belong to multiple groups
types, ok := groupTypes[groupTypeDefault]
if !ok {
types = make(map[string]spec.Type)
}
types[typeName] = v
groupTypes[groupTypeDefault] = types
} }
for group, typeGroup := range groupTypes { for group, typeGroup := range groupTypes {
@@ -142,7 +205,7 @@ func writeTypes(dir, baseFilename string, cfg *config.Config, types []spec.Type)
} }
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error { func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
if env.UseExperimental() { if VarBoolTypeGroup {
return genTypesWithGroup(dir, cfg, api) return genTypesWithGroup(dir, cfg, api)
} }
return writeTypes(dir, typesFile, cfg, api.Types) return writeTypes(dir, typesFile, cfg, api.Types)

View File

@@ -8,10 +8,10 @@ import (
"fmt" "fmt"
"io" "io"
"path" "path"
"slices"
"strings" "strings"
"text/template" "text/template"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/tools/goctl/api/spec" "github.com/zeromicro/go-zero/tools/goctl/api/spec"
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util" apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
"github.com/zeromicro/go-zero/tools/goctl/internal/version" "github.com/zeromicro/go-zero/tools/goctl/internal/version"
@@ -96,13 +96,13 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
for _, item := range c.responseTypes { for _, item := range c.responseTypes {
if item.Name() == defineStruct.Name() { if item.Name() == defineStruct.Name() {
superClassName = "HttpResponseData" superClassName = "HttpResponseData"
if !stringx.Contains(c.imports, httpResponseData) { if !slices.Contains(c.imports, httpResponseData) {
c.imports = append(c.imports, httpResponseData) c.imports = append(c.imports, httpResponseData)
} }
break break
} }
} }
if superClassName == "HttpData" && !stringx.Contains(c.imports, httpData) { if superClassName == "HttpData" && !slices.Contains(c.imports, httpData) {
c.imports = append(c.imports, httpData) c.imports = append(c.imports, httpData)
} }
@@ -266,7 +266,7 @@ func (c *componentsContext) genGetSet(writer io.Writer, indent int) error {
tyString := javaType tyString := javaType
decorator := "" decorator := ""
javaPrimitiveType := []string{"int", "long", "boolean", "float", "double", "short"} javaPrimitiveType := []string{"int", "long", "boolean", "float", "double", "short"}
if !stringx.Contains(javaPrimitiveType, javaType) { if !slices.Contains(javaPrimitiveType, javaType) {
if member.IsOptional() || member.IsOmitEmpty() { if member.IsOptional() || member.IsOmitEmpty() {
decorator = "@Nullable " decorator = "@Nullable "
} else { } else {

View File

@@ -3,9 +3,9 @@ package spec
import ( import (
"errors" "errors"
"path" "path"
"slices"
"strings" "strings"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/tools/goctl/util" "github.com/zeromicro/go-zero/tools/goctl/util"
) )
@@ -64,7 +64,7 @@ func (m Member) IsOptional() bool {
tag := m.Tags() tag := m.Tags()
for _, item := range tag { for _, item := range tag {
if item.Key == bodyTagKey || item.Key == formTagKey { if item.Key == bodyTagKey || item.Key == formTagKey {
if stringx.Contains(item.Options, "optional") { if slices.Contains(item.Options, "optional") {
return true return true
} }
} }
@@ -81,7 +81,7 @@ func (m Member) IsOmitEmpty() bool {
tag := m.Tags() tag := m.Tags()
for _, item := range tag { for _, item := range tag {
if item.Key == bodyTagKey { if item.Key == bodyTagKey {
if stringx.Contains(item.Options, "omitempty") { if slices.Contains(item.Options, "omitempty") {
return true return true
} }
} }
@@ -93,7 +93,7 @@ func (m Member) IsOmitEmpty() bool {
func (m Member) GetPropertyName() (string, error) { func (m Member) GetPropertyName() (string, error) {
tags := m.Tags() tags := m.Tags()
for _, tag := range tags { for _, tag := range tags {
if stringx.Contains(definedKeys, tag.Key) { if slices.Contains(definedKeys, tag.Key) {
if tag.Name == "-" { if tag.Name == "-" {
return util.Untitle(m.Name), nil return util.Untitle(m.Name), nil
} }

View File

@@ -49,6 +49,9 @@ func Parse(tag string) (*Tags, error) {
// Get gets tag value by specified key // Get gets tag value by specified key
func (t *Tags) Get(key string) (*Tag, error) { func (t *Tags) Get(key string) (*Tag, error) {
if t == nil {
return nil, errTagNotExist
}
for _, tag := range t.tags { for _, tag := range t.tags {
if tag.Key == key { if tag.Key == key {
return tag, nil return tag, nil
@@ -60,6 +63,9 @@ func (t *Tags) Get(key string) (*Tag, error) {
// Keys returns all keys in Tags // Keys returns all keys in Tags
func (t *Tags) Keys() []string { func (t *Tags) Keys() []string {
if t == nil {
return []string{}
}
var keys []string var keys []string
for _, tag := range t.tags { for _, tag := range t.tags {
keys = append(keys, tag.Key) keys = append(keys, tag.Key)
@@ -69,5 +75,8 @@ func (t *Tags) Keys() []string {
// Tags returns all tags in Tags // Tags returns all tags in Tags
func (t *Tags) Tags() []*Tag { func (t *Tags) Tags() []*Tag {
if t == nil {
return []*Tag{}
}
return t.tags return t.tags
} }

View File

@@ -0,0 +1,68 @@
package spec
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestTags_Get(t *testing.T) {
tags := &Tags{
tags: []*Tag{
{Key: "json", Name: "foo", Options: []string{"omitempty"}},
{Key: "xml", Name: "bar", Options: nil},
},
}
tag, err := tags.Get("json")
assert.NoError(t, err)
assert.NotNil(t, tag)
assert.Equal(t, "json", tag.Key)
assert.Equal(t, "foo", tag.Name)
_, err = tags.Get("yaml")
assert.Error(t, err)
var nilTags *Tags
_, err = nilTags.Get("json")
assert.Error(t, err)
}
func TestTags_Keys(t *testing.T) {
tags := &Tags{
tags: []*Tag{
{Key: "json", Name: "foo", Options: []string{"omitempty"}},
{Key: "xml", Name: "bar", Options: nil},
},
}
keys := tags.Keys()
expected := []string{"json", "xml"}
assert.Equal(t, expected, keys)
var nilTags *Tags
nilKeys := nilTags.Keys()
assert.Empty(t, nilKeys)
}
func TestTags_Tags(t *testing.T) {
tags := &Tags{
tags: []*Tag{
{Key: "json", Name: "foo", Options: []string{"omitempty"}},
{Key: "xml", Name: "bar", Options: nil},
},
}
result := tags.Tags()
assert.Len(t, result, 2)
assert.Equal(t, "json", result[0].Key)
assert.Equal(t, "foo", result[0].Name)
assert.Equal(t, []string{"omitempty"}, result[0].Options)
assert.Equal(t, "xml", result[1].Key)
assert.Equal(t, "bar", result[1].Name)
assert.Nil(t, result[1].Options)
var nilTags *Tags
nilResult := nilTags.Tags()
assert.Empty(t, nilResult)
}

View File

@@ -7,15 +7,6 @@ import (
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
) )
func hasKey(properties map[string]string, key string) bool {
if len(properties) == 0 {
return false
}
md := metadata.New(properties)
_, ok := md[key]
return ok
}
func getBoolFromKVOrDefault(properties map[string]string, key string, def bool) bool { func getBoolFromKVOrDefault(properties map[string]string, key string, def bool) bool {
if len(properties) == 0 { if len(properties) == 0 {
return def return def

View File

@@ -42,7 +42,7 @@ func Test_getListFromInfoOrDefault(t *testing.T) {
"empty": `""`, "empty": `""`,
} }
assert.Equal(t, []string{"a", "b", "c"}, getListFromInfoOrDefault(properties, "list", []string{"default"})) assert.Equal(t, []string{"a", " b", " c"}, getListFromInfoOrDefault(properties, "list", []string{"default"}))
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "empty", []string{"default"})) assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "empty", []string{"default"}))
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "missing", []string{"default"})) assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "missing", []string{"default"}))
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(nil, "nil", []string{"default"})) assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(nil, "nil", []string{"default"}))

View File

@@ -8,10 +8,9 @@ import (
"strings" "strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"gopkg.in/yaml.v2"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/parser" "github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/parser"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx" "github.com/zeromicro/go-zero/tools/goctl/util/pathx"
"gopkg.in/yaml.v2"
) )
var ( var (

View File

@@ -1,14 +1,15 @@
package swagger package swagger
const ( const (
tagHeader = "header" tagHeader = "header"
tagPath = "path" tagPath = "path"
tagForm = "form" tagForm = "form"
tagJson = "json" tagJson = "json"
defFlag = "default=" defFlag = "default="
enumFlag = "options=" enumFlag = "options="
rangeFlag = "range=" rangeFlag = "range="
exampleFlag = "example=" exampleFlag = "example="
optionalFlag = "optional"
paramsInHeader = "header" paramsInHeader = "header"
paramsInPath = "path" paramsInPath = "path"
@@ -27,6 +28,38 @@ const (
applicationJson = "application/json" applicationJson = "application/json"
applicationForm = "application/x-www-form-urlencoded" applicationForm = "application/x-www-form-urlencoded"
schemeHttps = "https" schemeHttps = "https"
defaultHost = "127.0.0.1"
defaultBasePath = "/" defaultBasePath = "/"
) )
const (
propertyKeyUseDefinitions = "useDefinitions"
propertyKeyExternalDocsDescription = "externalDocsDescription"
propertyKeyExternalDocsURL = "externalDocsURL"
propertyKeyTitle = "title"
propertyKeyTermsOfService = "termsOfService"
propertyKeyDescription = "description"
propertyKeyVersion = "version"
propertyKeyContactName = "contactName"
propertyKeyContactURL = "contactURL"
propertyKeyContactEmail = "contactEmail"
propertyKeyLicenseName = "licenseName"
propertyKeyLicenseURL = "licenseURL"
propertyKeyProduces = "produces"
propertyKeyConsumes = "consumes"
propertyKeySchemes = "schemes"
propertyKeyTags = "tags"
propertyKeySummary = "summary"
propertyKeyGroup = "group"
propertyKeyOperationId = "operationId"
propertyKeyDeprecated = "deprecated"
propertyKeyPrefix = "prefix"
propertyKeyAuthType = "authType"
propertyKeyHost = "host"
propertyKeyBasePath = "basePath"
propertyKeyWrapCodeMsg = "wrapCodeMsg"
propertyKeyBizCodeEnumDescription = "bizCodeEnumDescription"
)
const (
defaultValueOfPropertyUseDefinition = false
)

View File

@@ -7,7 +7,7 @@ import (
"github.com/zeromicro/go-zero/tools/goctl/api/spec" "github.com/zeromicro/go-zero/tools/goctl/api/spec"
) )
func consumesFromTypeOrDef(method string, tp spec.Type) []string { func consumesFromTypeOrDef(ctx Context, method string, tp spec.Type) []string {
if strings.EqualFold(method, http.MethodGet) { if strings.EqualFold(method, http.MethodGet) {
return []string{} return []string{}
} }
@@ -18,7 +18,7 @@ func consumesFromTypeOrDef(method string, tp spec.Type) []string {
if !ok { if !ok {
return []string{} return []string{}
} }
if typeContainsTag(structType, tagJson) { if typeContainsTag(ctx, structType, tagJson) {
return []string{applicationJson} return []string{applicationJson}
} }
return []string{applicationForm} return []string{applicationForm}

View File

@@ -61,7 +61,7 @@ func TestConsumesFromTypeOrDef(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := consumesFromTypeOrDef(tt.method, tt.tp) result := consumesFromTypeOrDef(testingContext(t), tt.method, tt.tp)
assert.Equal(t, tt.expected, result) assert.Equal(t, tt.expected, result)
}) })
} }

View File

@@ -0,0 +1,28 @@
package swagger
import (
"testing"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
)
type Context struct {
UseDefinitions bool
WrapCodeMsg bool
BizCodeEnumDescription string
}
func testingContext(_ *testing.T) Context {
return Context{}
}
func contextFromApi(info spec.Info) Context {
if len(info.Properties) == 0 {
return Context{}
}
return Context{
UseDefinitions: getBoolFromKVOrDefault(info.Properties, propertyKeyUseDefinitions, defaultValueOfPropertyUseDefinition),
WrapCodeMsg: getBoolFromKVOrDefault(info.Properties, propertyKeyWrapCodeMsg, false),
BizCodeEnumDescription: getStringFromKVOrDefault(info.Properties, propertyKeyBizCodeEnumDescription, "business code"),
}
}

View File

@@ -0,0 +1,32 @@
package swagger
import (
"github.com/go-openapi/spec"
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
)
func definitionsFromTypes(ctx Context, types []apiSpec.Type) spec.Definitions {
if !ctx.UseDefinitions {
return nil
}
definitions := make(spec.Definitions)
for _, tp := range types {
typeName := tp.Name()
definitions[typeName] = schemaFromType(ctx, tp)
}
return definitions
}
func schemaFromType(ctx Context, tp apiSpec.Type) spec.Schema {
p, r := propertiesFromType(ctx, tp)
props := spec.SchemaProps{
Type: typeFromGoType(ctx, tp),
Properties: p,
AdditionalProperties: mapFromGoType(ctx, tp),
Items: itemsFromGoType(ctx, tp),
Required: r,
}
return spec.Schema{
SchemaProps: props,
}
}

View File

@@ -12,15 +12,16 @@ info (
licenseURL: "https://github.com/zeromicro/go-zero" // licenseURL corresponding to Swagger licenseURL: "https://github.com/zeromicro/go-zero" // licenseURL corresponding to Swagger
consumes: "application/json" // consumes corresponding to Swagger,default value is `application/json` consumes: "application/json" // consumes corresponding to Swagger,default value is `application/json`
produces: "application/json" // produces corresponding to Swagger,default value is `application/json` produces: "application/json" // produces corresponding to Swagger,default value is `application/json`
schemes: "https" // schemes corresponding to Swagger,default value is `https`` schemes: "http,https" // schemes corresponding to Swagger,default value is `https``
host: "example.com" // host corresponding to Swagger,default value is `127.0.0.1` host: "example.com" // host corresponding to Swagger,default value is `127.0.0.1`
basePath: "/v1" // basePath corresponding to Swagger,default value is `/` basePath: "/v1" // basePath corresponding to Swagger,default value is `/`
wrapCodeMsg: "true" // to wrap in the universal code-msg structure, like {"code":0,"msg":"OK","data":$data} wrapCodeMsg: true // to wrap in the universal code-msg structure, like {"code":0,"msg":"OK","data":$data}
bizCodeEnumDescription: "1001-User not login<br>1002-User permission denied" // enums of business error codes, in JSON format, with the key being the business error code and the value being the description of that error code. This only takes effect when wrapCodeMsg is set to true. bizCodeEnumDescription: "1001-User not login<br>1002-User permission denied" // enums of business error codes, in JSON format, with the key being the business error code and the value being the description of that error code. This only takes effect when wrapCodeMsg is set to true.
// securityDefinitionsFromJson is a custom authentication configuration, and the JSON content will be directly inserted into the securityDefinitions of Swagger. // securityDefinitionsFromJson is a custom authentication configuration, and the JSON content will be directly inserted into the securityDefinitions of Swagger.
// Format reference: https://swagger.io/specification/v2/#security-definitions-object // Format reference: https://swagger.io/specification/v2/#security-definitions-object
// You can declare authType in the @server of the API to specify the authentication type used for its routes. // You can declare authType in the @server of the API to specify the authentication type used for its routes.
securityDefinitionsFromJson: `{"apiKey":{"description":"apiKey type description","type":"apiKey","name":"x-api-key","in":"header"}}` securityDefinitionsFromJson: `{"apiKey":{"description":"apiKey type description","type":"apiKey","name":"x-api-key","in":"header"}}`
useDefinitions: true // if set true, the definitions will be generated in the swagger.json for response body or json request body file, and the models will be referenced in the API.
) )
type ( type (

File diff suppressed because it is too large Load Diff

View File

@@ -12,15 +12,16 @@ info (
licenseURL: "https://github.com/zeromicro/go-zero" // 对应 swagger 的 licenseURL licenseURL: "https://github.com/zeromicro/go-zero" // 对应 swagger 的 licenseURL
consumes: "application/json" // 对应 swagger 的 consumes,不填默认为 application/json consumes: "application/json" // 对应 swagger 的 consumes,不填默认为 application/json
produces: "application/json" // 对应 swagger 的 produces,不填默认为 application/json produces: "application/json" // 对应 swagger 的 produces,不填默认为 application/json
schemes: "https" // 对应 swagger 的 schemes,不填默认为 https schemes: "http,https" // 对应 swagger 的 schemes,不填默认为 https
host: "example.com" // 对应 swagger 的 host,不填默认为 127.0.0.1 host: "example.com" // 对应 swagger 的 host,不填默认为 127.0.0.1
basePath: "/v1" // 对应 swagger 的 basePath,不填默认为 / basePath: "/v1" // 对应 swagger 的 basePath,不填默认为 /
wrapCodeMsg: "true" // 是否用 code-msg 通用响应体,如果开启,则以格式 {"code":0,"msg":"OK","data":$data} 包括响应体 wrapCodeMsg: true // 是否用 code-msg 通用响应体,如果开启,则以格式 {"code":0,"msg":"OK","data":$data} 包括响应体
bizCodeEnumDescription: "1001-未登录<br>1002-无权限操作" // 全局业务错误码枚举描述json 格式,key 为业务错误码value 为该错误码的描述,仅当 wrapCodeMsg 为 true 时生效 bizCodeEnumDescription: "1001-未登录<br>1002-无权限操作" // 全局业务错误码枚举描述json 格式,key 为业务错误码value 为该错误码的描述,仅当 wrapCodeMsg 为 true 时生效
// securityDefinitionsFromJson 为自定义鉴权配置json 内容将直接放入 swagger 的 securityDefinitions 中, // securityDefinitionsFromJson 为自定义鉴权配置json 内容将直接放入 swagger 的 securityDefinitions 中,
// 格式参考 https://swagger.io/specification/v2/#security-definitions-object // 格式参考 https://swagger.io/specification/v2/#security-definitions-object
// 在 api 的 @server 中可声明 authType 来指定其路由使用的鉴权类型 // 在 api 的 @server 中可声明 authType 来指定其路由使用的鉴权类型
securityDefinitionsFromJson: `{"apiKey":{"description":"apiKey 类型鉴权自定义","type":"apiKey","name":"x-api-key","in":"header"}}` securityDefinitionsFromJson: `{"apiKey":{"description":"apiKey 类型鉴权自定义","type":"apiKey","name":"x-api-key","in":"header"}}`
useDefinitions: true// 开启声明将生成models 进行关联definitions 仅对响应体和 json 请求体生效
) )
type ( type (
@@ -48,11 +49,12 @@ type (
summary: "query 类型接口集合" // 对应 swagger 的 summary summary: "query 类型接口集合" // 对应 swagger 的 summary
prefix: v1 prefix: v1
authType: apiKey // 指定该路由使用的鉴权类型,值为 securityDefinitionsFromJson 中定义的名称 authType: apiKey // 指定该路由使用的鉴权类型,值为 securityDefinitionsFromJson 中定义的名称
group:"demo"
) )
service Swagger { service Swagger {
@doc ( @doc (
description: "query 接口" description: "query 接口"
bizCodeEnumDescription: " 1003-用不存在<br>1004-非法操作" // 接口级别业务错误码枚举描述会覆盖全局的业务错误码json 格式,key 为业务错误码value 为该错误码的描述,仅当 wrapCodeMsg 为 true 时生效 bizCodeEnumDescription: " 1003-用不存在<br>1004-非法操作" // 接口级别业务错误码枚举描述会覆盖全局的业务错误码json 格式,key 为业务错误码value 为该错误码的描述,仅当 wrapCodeMsg 为 true 且 useDefinitions 为 false 时生效
) )
@handler query @handler query
get /query (QueryReq) returns (QueryResp) get /query (QueryReq) returns (QueryResp)

File diff suppressed because it is too large Load Diff

View File

@@ -81,21 +81,21 @@ func enumsValueFromOptions(options []string) []any {
return []any{} return []any{}
} }
func defValueFromOptions(options []string, apiType spec.Type) any { func defValueFromOptions(ctx Context, options []string, apiType spec.Type) any {
tp := sampleTypeFromGoType(apiType) tp := sampleTypeFromGoType(ctx, apiType)
return valueFromOptions(options, defFlag, tp) return valueFromOptions(ctx, options, defFlag, tp)
} }
func exampleValueFromOptions(options []string, apiType spec.Type) any { func exampleValueFromOptions(ctx Context, options []string, apiType spec.Type) any {
tp := sampleTypeFromGoType(apiType) tp := sampleTypeFromGoType(ctx, apiType)
val := valueFromOptions(options, exampleFlag, tp) val := valueFromOptions(ctx, options, exampleFlag, tp)
if val != nil { if val != nil {
return val return val
} }
return defValueFromOptions(options, apiType) return defValueFromOptions(ctx, options, apiType)
} }
func valueFromOptions(options []string, key string, tp string) any { func valueFromOptions(_ Context, options []string, key string, tp string) any {
if len(options) == 0 { if len(options) == 0 {
return nil return nil
} }
@@ -103,16 +103,18 @@ func valueFromOptions(options []string, key string, tp string) any {
if strings.HasPrefix(option, key) { if strings.HasPrefix(option, key) {
s := option[len(key):] s := option[len(key):]
switch tp { switch tp {
case "integer": case swaggerTypeInteger:
val, _ := strconv.ParseInt(s, 10, 64) val, _ := strconv.ParseInt(s, 10, 64)
return val return val
case "boolean": case swaggerTypeBoolean:
val, _ := strconv.ParseBool(s) val, _ := strconv.ParseBool(s)
return val return val
case "number": case swaggerTypeNumber:
val, _ := strconv.ParseFloat(s, 64) val, _ := strconv.ParseFloat(s, 64)
return val return val
case "string": case swaggerTypeArray:
return s
case swaggerTypeString:
return s return s
default: default:
return nil return nil

View File

@@ -161,7 +161,7 @@ func TestDefValueFromOptions(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := defValueFromOptions(tt.options, tt.apiType) result := defValueFromOptions(testingContext(t), tt.options, tt.apiType)
assert.Equal(t, tt.expected, result) assert.Equal(t, tt.expected, result)
}) })
} }
@@ -202,7 +202,7 @@ func TestExampleValueFromOptions(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
exampleValueFromOptions(tt.options, tt.apiType) exampleValueFromOptions(testingContext(t), tt.options, tt.apiType)
}) })
} }
} }
@@ -247,7 +247,7 @@ func TestValueFromOptions(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := valueFromOptions(tt.options, tt.key, tt.tp) result := valueFromOptions(testingContext(t), tt.options, tt.key, tt.tp)
assert.Equal(t, tt.expected, result) assert.Equal(t, tt.expected, result)
}) })
} }

View File

@@ -8,7 +8,25 @@ import (
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec" apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
) )
func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter { func isPostJson(ctx Context, method string, tp apiSpec.Type) (string, bool) {
if strings.EqualFold(method, http.MethodPost) {
return "", false
}
structType, ok := tp.(apiSpec.DefineStruct)
if !ok {
return "", false
}
var isPostJson bool
rangeMemberAndDo(ctx, structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
jsonTag, _ := tag.Get(tagJson)
if !isPostJson {
isPostJson = jsonTag != nil
}
})
return structType.RawName, isPostJson
}
func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Parameter {
if tp == nil { if tp == nil {
return []spec.Parameter{} return []spec.Parameter{}
} }
@@ -16,12 +34,13 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
if !ok { if !ok {
return []spec.Parameter{} return []spec.Parameter{}
} }
var ( var (
resp []spec.Parameter resp []spec.Parameter
properties = map[string]spec.Schema{} properties = map[string]spec.Schema{}
requiredFields []string requiredFields []string
) )
rangeMemberAndDo(structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) { rangeMemberAndDo(ctx, structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
headerTag, _ := tag.Get(tagHeader) headerTag, _ := tag.Get(tagHeader)
hasHeader := headerTag != nil hasHeader := headerTag != nil
@@ -44,10 +63,9 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
Enum: enumsValueFromOptions(headerTag.Options), Enum: enumsValueFromOptions(headerTag.Options),
}, },
SimpleSchema: spec.SimpleSchema{ SimpleSchema: spec.SimpleSchema{
Type: sampleTypeFromGoType(member.Type), Type: sampleTypeFromGoType(ctx, member.Type),
Default: defValueFromOptions(headerTag.Options, member.Type), Default: defValueFromOptions(ctx, headerTag.Options, member.Type),
Example: exampleValueFromOptions(headerTag.Options, member.Type), Items: sampleItemsFromGoType(ctx, member.Type),
Items: sampleItemsFromGoType(member.Type),
}, },
ParamProps: spec.ParamProps{ ParamProps: spec.ParamProps{
In: paramsInHeader, In: paramsInHeader,
@@ -68,10 +86,9 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
Enum: enumsValueFromOptions(pathParameterTag.Options), Enum: enumsValueFromOptions(pathParameterTag.Options),
}, },
SimpleSchema: spec.SimpleSchema{ SimpleSchema: spec.SimpleSchema{
Type: sampleTypeFromGoType(member.Type), Type: sampleTypeFromGoType(ctx, member.Type),
Default: defValueFromOptions(pathParameterTag.Options, member.Type), Default: defValueFromOptions(ctx, pathParameterTag.Options, member.Type),
Example: exampleValueFromOptions(pathParameterTag.Options, member.Type), Items: sampleItemsFromGoType(ctx, member.Type),
Items: sampleItemsFromGoType(member.Type),
}, },
ParamProps: spec.ParamProps{ ParamProps: spec.ParamProps{
In: paramsInPath, In: paramsInPath,
@@ -93,10 +110,9 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
Enum: enumsValueFromOptions(formTag.Options), Enum: enumsValueFromOptions(formTag.Options),
}, },
SimpleSchema: spec.SimpleSchema{ SimpleSchema: spec.SimpleSchema{
Type: sampleTypeFromGoType(member.Type), Type: sampleTypeFromGoType(ctx, member.Type),
Default: defValueFromOptions(formTag.Options, member.Type), Default: defValueFromOptions(ctx, formTag.Options, member.Type),
Example: exampleValueFromOptions(formTag.Options, member.Type), Items: sampleItemsFromGoType(ctx, member.Type),
Items: sampleItemsFromGoType(member.Type),
}, },
ParamProps: spec.ParamProps{ ParamProps: spec.ParamProps{
In: paramsInQuery, In: paramsInQuery,
@@ -116,10 +132,9 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
Enum: enumsValueFromOptions(formTag.Options), Enum: enumsValueFromOptions(formTag.Options),
}, },
SimpleSchema: spec.SimpleSchema{ SimpleSchema: spec.SimpleSchema{
Type: sampleTypeFromGoType(member.Type), Type: sampleTypeFromGoType(ctx, member.Type),
Default: defValueFromOptions(formTag.Options, member.Type), Default: defValueFromOptions(ctx, formTag.Options, member.Type),
Example: exampleValueFromOptions(formTag.Options, member.Type), Items: sampleItemsFromGoType(ctx, member.Type),
Items: sampleItemsFromGoType(member.Type),
}, },
ParamProps: spec.ParamProps{ ParamProps: spec.ParamProps{
In: paramsInForm, In: paramsInForm,
@@ -139,25 +154,25 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
} }
var schema = spec.Schema{ var schema = spec.Schema{
SwaggerSchemaProps: spec.SwaggerSchemaProps{ SwaggerSchemaProps: spec.SwaggerSchemaProps{
Example: exampleValueFromOptions(jsonTag.Options, member.Type), Example: exampleValueFromOptions(ctx, jsonTag.Options, member.Type),
}, },
SchemaProps: spec.SchemaProps{ SchemaProps: spec.SchemaProps{
Description: formatComment(member.Comment), Description: formatComment(member.Comment),
Type: typeFromGoType(member.Type), Type: typeFromGoType(ctx, member.Type),
Default: defValueFromOptions(jsonTag.Options, member.Type), Default: defValueFromOptions(ctx, jsonTag.Options, member.Type),
Maximum: maximum, Maximum: maximum,
ExclusiveMaximum: exclusiveMaximum, ExclusiveMaximum: exclusiveMaximum,
Minimum: minimum, Minimum: minimum,
ExclusiveMinimum: exclusiveMinimum, ExclusiveMinimum: exclusiveMinimum,
Enum: enumsValueFromOptions(jsonTag.Options), Enum: enumsValueFromOptions(jsonTag.Options),
AdditionalProperties: mapFromGoType(member.Type), AdditionalProperties: mapFromGoType(ctx, member.Type),
}, },
} }
switch sampleTypeFromGoType(member.Type) { switch sampleTypeFromGoType(ctx, member.Type) {
case swaggerTypeArray: case swaggerTypeArray:
schema.Items = itemsFromGoType(member.Type) schema.Items = itemsFromGoType(ctx, member.Type)
case swaggerTypeObject: case swaggerTypeObject:
p, r := propertiesFromType(member.Type) p, r := propertiesFromType(ctx, member.Type)
schema.Properties = p schema.Properties = p
schema.Required = r schema.Required = r
} }
@@ -165,20 +180,38 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
} }
}) })
if len(properties) > 0 { if len(properties) > 0 {
resp = append(resp, spec.Parameter{ if ctx.UseDefinitions {
ParamProps: spec.ParamProps{ structName, ok := isPostJson(ctx, method, tp)
In: paramsInBody, if ok {
Name: paramsInBody, resp = append(resp, spec.Parameter{
Required: true, ParamProps: spec.ParamProps{
Schema: &spec.Schema{ In: paramsInBody,
SchemaProps: spec.SchemaProps{ Name: paramsInBody,
Type: typeFromGoType(structType), Required: true,
Properties: properties, Schema: &spec.Schema{
Required: requiredFields, SchemaProps: spec.SchemaProps{
Ref: spec.MustCreateRef(getRefName(structName)),
},
},
},
})
}
} else {
resp = append(resp, spec.Parameter{
ParamProps: spec.ParamProps{
In: paramsInBody,
Name: paramsInBody,
Required: true,
Schema: &spec.Schema{
SchemaProps: spec.SchemaProps{
Type: typeFromGoType(ctx, structType),
Properties: properties,
Required: requiredFields,
},
}, },
}, },
}, })
}) }
} }
return resp return resp
} }

View File

@@ -7,20 +7,21 @@ import (
"github.com/go-openapi/spec" "github.com/go-openapi/spec"
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec" apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
"github.com/zeromicro/go-zero/tools/goctl/util/stringx"
) )
func spec2Paths(info apiSpec.Info, srv apiSpec.Service) *spec.Paths { func spec2Paths(ctx Context, srv apiSpec.Service) *spec.Paths {
paths := &spec.Paths{ paths := &spec.Paths{
Paths: make(map[string]spec.PathItem), Paths: make(map[string]spec.PathItem),
} }
for _, group := range srv.Groups { for _, group := range srv.Groups {
prefix := path.Clean(strings.TrimPrefix(group.GetAnnotation("prefix"), "/")) prefix := path.Clean(strings.TrimPrefix(group.GetAnnotation(propertyKeyPrefix), "/"))
for _, route := range group.Routes { for _, route := range group.Routes {
routPath := pathVariable2SwaggerVariable(route.Path) routPath := pathVariable2SwaggerVariable(ctx, route.Path)
if len(prefix) > 0 && prefix != "." { if len(prefix) > 0 && prefix != "." {
routPath = "/" + path.Clean(prefix) + routPath routPath = "/" + path.Clean(prefix) + routPath
} }
pathItem := spec2Path(info, group, route) pathItem := spec2Path(ctx, group, route)
existPathItem, ok := paths.Paths[routPath] existPathItem, ok := paths.Paths[routPath]
if !ok { if !ok {
paths.Paths[routPath] = pathItem paths.Paths[routPath] = pathItem
@@ -60,8 +61,8 @@ func mergePathItem(old, new spec.PathItem) spec.PathItem {
return old return old
} }
func spec2Path(info apiSpec.Info, group apiSpec.Group, route apiSpec.Route) spec.PathItem { func spec2Path(ctx Context, group apiSpec.Group, route apiSpec.Route) spec.PathItem {
authType := getStringFromKVOrDefault(group.Annotation.Properties, "authType", "") authType := getStringFromKVOrDefault(group.Annotation.Properties, propertyKeyAuthType, "")
var security []map[string][]string var security []map[string][]string
if len(authType) > 0 { if len(authType) > 0 {
security = []map[string][]string{ security = []map[string][]string{
@@ -70,22 +71,29 @@ func spec2Path(info apiSpec.Info, group apiSpec.Group, route apiSpec.Route) spec
}, },
} }
} }
groupName := getStringFromKVOrDefault(group.Annotation.Properties, propertyKeyGroup, "")
operationId := route.Handler
if len(groupName) > 0 {
operationId = stringx.From(groupName + "_" + route.Handler).ToCamel()
}
operationId = stringx.From(operationId).Untitle()
op := &spec.Operation{ op := &spec.Operation{
OperationProps: spec.OperationProps{ OperationProps: spec.OperationProps{
Description: getStringFromKVOrDefault(route.AtDoc.Properties, "description", ""), Description: getStringFromKVOrDefault(route.AtDoc.Properties, propertyKeyDescription, ""),
Consumes: consumesFromTypeOrDef(route.Method, route.RequestType), Consumes: consumesFromTypeOrDef(ctx, route.Method, route.RequestType),
Produces: getListFromInfoOrDefault(route.AtDoc.Properties, "produces", []string{applicationJson}), Produces: getListFromInfoOrDefault(route.AtDoc.Properties, propertyKeyProduces, []string{applicationJson}),
Schemes: getListFromInfoOrDefault(route.AtDoc.Properties, "schemes", []string{schemeHttps}), Schemes: getListFromInfoOrDefault(route.AtDoc.Properties, propertyKeySchemes, []string{schemeHttps}),
Tags: getListFromInfoOrDefault(group.Annotation.Properties, "tags", []string{""}), Tags: getListFromInfoOrDefault(group.Annotation.Properties, propertyKeyTags, getListFromInfoOrDefault(group.Annotation.Properties, propertyKeySummary, []string{})),
Summary: getStringFromKVOrDefault(route.AtDoc.Properties, "summary", getFirstUsableString(route.AtDoc.Text, route.Handler)), Summary: getStringFromKVOrDefault(route.AtDoc.Properties, propertyKeySummary, getFirstUsableString(route.AtDoc.Text, route.Handler)),
Deprecated: getBoolFromKVOrDefault(route.AtDoc.Properties, "deprecated", false), ID: operationId,
Parameters: parametersFromType(route.Method, route.RequestType), Deprecated: getBoolFromKVOrDefault(route.AtDoc.Properties, propertyKeyDeprecated, false),
Responses: jsonResponseFromType(info, route.AtDoc, route.ResponseType), Parameters: parametersFromType(ctx, route.Method, route.RequestType),
Security: security, Security: security,
Responses: jsonResponseFromType(ctx, route.AtDoc, route.ResponseType),
}, },
} }
externalDocsDescription := getStringFromKVOrDefault(route.AtDoc.Properties, "externalDocsDescription", "") externalDocsDescription := getStringFromKVOrDefault(route.AtDoc.Properties, propertyKeyExternalDocsDescription, "")
externalDocsURL := getStringFromKVOrDefault(route.AtDoc.Properties, "externalDocsURL", "") externalDocsURL := getStringFromKVOrDefault(route.AtDoc.Properties, propertyKeyExternalDocsURL, "")
if len(externalDocsDescription) > 0 || len(externalDocsURL) > 0 { if len(externalDocsDescription) > 0 || len(externalDocsURL) > 0 {
op.ExternalDocs = &spec.ExternalDocumentation{ op.ExternalDocs = &spec.ExternalDocumentation{
Description: externalDocsDescription, Description: externalDocsDescription,

View File

@@ -5,18 +5,18 @@ import (
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec" apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
) )
func propertiesFromType(tp apiSpec.Type) (spec.SchemaProperties, []string) { func propertiesFromType(ctx Context, tp apiSpec.Type) (spec.SchemaProperties, []string) {
var ( var (
properties = map[string]spec.Schema{} properties = map[string]spec.Schema{}
requiredFields []string requiredFields []string
) )
switch val := tp.(type) { switch val := tp.(type) {
case apiSpec.PointerType: case apiSpec.PointerType:
return propertiesFromType(val.Type) return propertiesFromType(ctx, val.Type)
case apiSpec.ArrayType: case apiSpec.ArrayType:
return propertiesFromType(val.Value) return propertiesFromType(ctx, val.Value)
case apiSpec.DefineStruct, apiSpec.NestedStruct: case apiSpec.DefineStruct, apiSpec.NestedStruct:
rangeMemberAndDo(val, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) { rangeMemberAndDo(ctx, val, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
var ( var (
jsonTagString = member.Name jsonTagString = member.Name
minimum, maximum *float64 minimum, maximum *float64
@@ -24,42 +24,63 @@ func propertiesFromType(tp apiSpec.Type) (spec.SchemaProperties, []string) {
example, defaultValue any example, defaultValue any
enum []any enum []any
) )
pathTag, _ := tag.Get(tagPath)
if pathTag != nil {
return
}
formTag, _ := tag.Get(tagForm)
if formTag != nil {
return
}
headerTag, _ := tag.Get(tagHeader)
if headerTag != nil {
return
}
jsonTag, _ := tag.Get(tagJson) jsonTag, _ := tag.Get(tagJson)
if jsonTag != nil { if jsonTag != nil {
jsonTagString = jsonTag.Name jsonTagString = jsonTag.Name
minimum, maximum, exclusiveMinimum, exclusiveMaximum = rangeValueFromOptions(jsonTag.Options) minimum, maximum, exclusiveMinimum, exclusiveMaximum = rangeValueFromOptions(jsonTag.Options)
example = exampleValueFromOptions(jsonTag.Options, member.Type) example = exampleValueFromOptions(ctx, jsonTag.Options, member.Type)
defaultValue = defValueFromOptions(jsonTag.Options, member.Type) defaultValue = defValueFromOptions(ctx, jsonTag.Options, member.Type)
enum = enumsValueFromOptions(jsonTag.Options) enum = enumsValueFromOptions(jsonTag.Options)
} }
if required { if required {
requiredFields = append(requiredFields, jsonTagString) requiredFields = append(requiredFields, jsonTagString)
} }
var schema = spec.Schema{
schema := spec.Schema{
SwaggerSchemaProps: spec.SwaggerSchemaProps{ SwaggerSchemaProps: spec.SwaggerSchemaProps{
Example: example, Example: example,
}, },
SchemaProps: spec.SchemaProps{ SchemaProps: spec.SchemaProps{
Description: formatComment(member.Comment), Description: formatComment(member.Comment),
Type: typeFromGoType(member.Type), Type: typeFromGoType(ctx, member.Type),
Default: defaultValue, Default: defaultValue,
Maximum: maximum, Maximum: maximum,
ExclusiveMaximum: exclusiveMaximum, ExclusiveMaximum: exclusiveMaximum,
Minimum: minimum, Minimum: minimum,
ExclusiveMinimum: exclusiveMinimum, ExclusiveMinimum: exclusiveMinimum,
Enum: enum, Enum: enum,
AdditionalProperties: mapFromGoType(member.Type), AdditionalProperties: mapFromGoType(ctx, member.Type),
}, },
} }
switch sampleTypeFromGoType(member.Type) {
switch sampleTypeFromGoType(ctx, member.Type) {
case swaggerTypeArray: case swaggerTypeArray:
schema.Items = itemsFromGoType(member.Type) schema.Items = itemsFromGoType(ctx, member.Type)
case swaggerTypeObject: case swaggerTypeObject:
p, r := propertiesFromType(member.Type) p, r := propertiesFromType(ctx, member.Type)
schema.Properties = p schema.Properties = p
schema.Required = r schema.Required = r
} }
if ctx.UseDefinitions {
structName, containsStruct := containsStruct(member.Type)
if containsStruct {
schema.SchemaProps.Ref = spec.MustCreateRef(getRefName(structName))
}
}
properties[jsonTagString] = schema properties[jsonTagString] = schema
}) })
@@ -67,3 +88,22 @@ func propertiesFromType(tp apiSpec.Type) (spec.SchemaProperties, []string) {
return properties, requiredFields return properties, requiredFields
} }
func containsStruct(tp apiSpec.Type) (string, bool) {
switch val := tp.(type) {
case apiSpec.PointerType:
return containsStruct(val.Type)
case apiSpec.ArrayType:
return containsStruct(val.Value)
case apiSpec.DefineStruct:
return val.RawName, true
case apiSpec.MapType:
return containsStruct(val.Value)
default:
return "", false
}
}
func getRefName(typeName string) string {
return "#/definitions/" + typeName
}

View File

@@ -1,25 +1,62 @@
package swagger package swagger
import ( import (
"net/http"
"github.com/go-openapi/spec" "github.com/go-openapi/spec"
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec" apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
) )
func jsonResponseFromType(info apiSpec.Info, atDoc apiSpec.AtDoc, tp apiSpec.Type) *spec.Responses { func jsonResponseFromType(ctx Context, atDoc apiSpec.AtDoc, tp apiSpec.Type) *spec.Responses {
p, _ := propertiesFromType(tp) if tp == nil {
return &spec.Responses{
ResponsesProps: spec.ResponsesProps{
StatusCodeResponses: map[int]spec.Response{
http.StatusOK: {
ResponseProps: spec.ResponseProps{
Description: "",
Schema: &spec.Schema{},
},
},
},
},
}
}
props := spec.SchemaProps{ props := spec.SchemaProps{
Type: typeFromGoType(tp), AdditionalProperties: mapFromGoType(ctx, tp),
Properties: p, Items: itemsFromGoType(ctx, tp),
AdditionalProperties: mapFromGoType(tp), }
Items: itemsFromGoType(tp), if ctx.UseDefinitions {
structName, ok := containsStruct(tp)
if ok {
props.Ref = spec.MustCreateRef(getRefName(structName))
return &spec.Responses{
ResponsesProps: spec.ResponsesProps{
StatusCodeResponses: map[int]spec.Response{
http.StatusOK: {
ResponseProps: spec.ResponseProps{
Schema: &spec.Schema{
SchemaProps: wrapCodeMsgProps(ctx, props, atDoc),
},
},
},
},
},
}
}
} }
p, _ := propertiesFromType(ctx, tp)
props.Type = typeFromGoType(ctx, tp)
props.Properties = p
return &spec.Responses{ return &spec.Responses{
ResponsesProps: spec.ResponsesProps{ ResponsesProps: spec.ResponsesProps{
Default: &spec.Response{ StatusCodeResponses: map[int]spec.Response{
ResponseProps: spec.ResponseProps{ http.StatusOK: {
Schema: &spec.Schema{ ResponseProps: spec.ResponseProps{
SchemaProps: wrapCodeMsgProps(props, info, atDoc), Schema: &spec.Schema{
SchemaProps: wrapCodeMsgProps(ctx, props, atDoc),
},
}, },
}, },
}, },

View File

@@ -8,12 +8,11 @@ import (
"github.com/go-openapi/spec" "github.com/go-openapi/spec"
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec" apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
"github.com/zeromicro/go-zero/tools/goctl/internal/version" "github.com/zeromicro/go-zero/tools/goctl/internal/version"
"github.com/zeromicro/go-zero/tools/goctl/util"
) )
func spec2Swagger(api *apiSpec.ApiSpec) (*spec.Swagger, error) { func spec2Swagger(api *apiSpec.ApiSpec) (*spec.Swagger, error) {
ctx := contextFromApi(api.Info)
extensions, info := specExtensions(api.Info) extensions, info := specExtensions(api.Info)
var securityDefinitions spec.SecurityDefinitions var securityDefinitions spec.SecurityDefinitions
securityDefinitionsFromJson := getStringFromKVOrDefault(api.Info.Properties, "securityDefinitionsFromJson", `{}`) securityDefinitionsFromJson := getStringFromKVOrDefault(api.Info.Properties, "securityDefinitionsFromJson", `{}`)
_ = json.Unmarshal([]byte(securityDefinitionsFromJson), &securityDefinitions) _ = json.Unmarshal([]byte(securityDefinitionsFromJson), &securityDefinitions)
@@ -22,14 +21,15 @@ func spec2Swagger(api *apiSpec.ApiSpec) (*spec.Swagger, error) {
Extensions: extensions, Extensions: extensions,
}, },
SwaggerProps: spec.SwaggerProps{ SwaggerProps: spec.SwaggerProps{
Consumes: getListFromInfoOrDefault(api.Info.Properties, "consumes", []string{applicationJson}), Definitions: definitionsFromTypes(ctx, api.Types),
Produces: getListFromInfoOrDefault(api.Info.Properties, "produces", []string{applicationJson}), Consumes: getListFromInfoOrDefault(api.Info.Properties, propertyKeyConsumes, []string{applicationJson}),
Schemes: getListFromInfoOrDefault(api.Info.Properties, "schemes", []string{schemeHttps}), Produces: getListFromInfoOrDefault(api.Info.Properties, propertyKeyProduces, []string{applicationJson}),
Schemes: getListFromInfoOrDefault(api.Info.Properties, propertyKeySchemes, []string{schemeHttps}),
Swagger: swaggerVersion, Swagger: swaggerVersion,
Info: info, Info: info,
Host: getStringFromKVOrDefault(api.Info.Properties, "host", defaultHost), Host: getStringFromKVOrDefault(api.Info.Properties, propertyKeyHost, ""),
BasePath: getStringFromKVOrDefault(api.Info.Properties, "basePath", defaultBasePath), BasePath: getStringFromKVOrDefault(api.Info.Properties, propertyKeyBasePath, defaultBasePath),
Paths: spec2Paths(api.Info, api.Service), Paths: spec2Paths(ctx, api.Service),
SecurityDefinitions: securityDefinitions, SecurityDefinitions: securityDefinitions,
}, },
} }
@@ -42,7 +42,7 @@ func formatComment(comment string) string {
return strings.TrimSpace(s) return strings.TrimSpace(s)
} }
func sampleItemsFromGoType(tp apiSpec.Type) *spec.Items { func sampleItemsFromGoType(ctx Context, tp apiSpec.Type) *spec.Items {
val, ok := tp.(apiSpec.ArrayType) val, ok := tp.(apiSpec.ArrayType)
if !ok { if !ok {
return nil return nil
@@ -52,14 +52,14 @@ func sampleItemsFromGoType(tp apiSpec.Type) *spec.Items {
case apiSpec.PrimitiveType: case apiSpec.PrimitiveType:
return &spec.Items{ return &spec.Items{
SimpleSchema: spec.SimpleSchema{ SimpleSchema: spec.SimpleSchema{
Type: sampleTypeFromGoType(item), Type: sampleTypeFromGoType(ctx, item),
}, },
} }
case apiSpec.ArrayType: case apiSpec.ArrayType:
return &spec.Items{ return &spec.Items{
SimpleSchema: spec.SimpleSchema{ SimpleSchema: spec.SimpleSchema{
Type: sampleTypeFromGoType(item), Type: sampleTypeFromGoType(ctx, item),
Items: sampleItemsFromGoType(item), Items: sampleItemsFromGoType(ctx, item),
}, },
} }
default: // unsupported type default: // unsupported type
@@ -68,30 +68,30 @@ func sampleItemsFromGoType(tp apiSpec.Type) *spec.Items {
} }
// itemsFromGoType returns the schema or array of the type, just for non json body parameters. // itemsFromGoType returns the schema or array of the type, just for non json body parameters.
func itemsFromGoType(tp apiSpec.Type) *spec.SchemaOrArray { func itemsFromGoType(ctx Context, tp apiSpec.Type) *spec.SchemaOrArray {
array, ok := tp.(apiSpec.ArrayType) array, ok := tp.(apiSpec.ArrayType)
if !ok { if !ok {
return nil return nil
} }
return itemFromGoType(array.Value) return itemFromGoType(ctx, array.Value)
} }
func mapFromGoType(tp apiSpec.Type) *spec.SchemaOrBool { func mapFromGoType(ctx Context, tp apiSpec.Type) *spec.SchemaOrBool {
mapType, ok := tp.(apiSpec.MapType) mapType, ok := tp.(apiSpec.MapType)
if !ok { if !ok {
return nil return nil
} }
var schema = &spec.Schema{ var schema = &spec.Schema{
SchemaProps: spec.SchemaProps{ SchemaProps: spec.SchemaProps{
Type: typeFromGoType(mapType.Value), Type: typeFromGoType(ctx, mapType.Value),
AdditionalProperties: mapFromGoType(mapType.Value), AdditionalProperties: mapFromGoType(ctx, mapType.Value),
}, },
} }
switch sampleTypeFromGoType(mapType.Value) { switch sampleTypeFromGoType(ctx, mapType.Value) {
case swaggerTypeArray: case swaggerTypeArray:
schema.Items = itemsFromGoType(mapType.Value) schema.Items = itemsFromGoType(ctx, mapType.Value)
case swaggerTypeObject: case swaggerTypeObject:
p, r := propertiesFromType(mapType.Value) p, r := propertiesFromType(ctx, mapType.Value)
schema.Properties = p schema.Properties = p
schema.Required = r schema.Required = r
} }
@@ -102,37 +102,37 @@ func mapFromGoType(tp apiSpec.Type) *spec.SchemaOrBool {
} }
// itemFromGoType returns the schema or array of the type, just for non json body parameters. // itemFromGoType returns the schema or array of the type, just for non json body parameters.
func itemFromGoType(tp apiSpec.Type) *spec.SchemaOrArray { func itemFromGoType(ctx Context, tp apiSpec.Type) *spec.SchemaOrArray {
switch itemType := tp.(type) { switch itemType := tp.(type) {
case apiSpec.PrimitiveType: case apiSpec.PrimitiveType:
return &spec.SchemaOrArray{ return &spec.SchemaOrArray{
Schema: &spec.Schema{ Schema: &spec.Schema{
SchemaProps: spec.SchemaProps{ SchemaProps: spec.SchemaProps{
Type: typeFromGoType(tp), Type: typeFromGoType(ctx, tp),
}, },
}, },
} }
case apiSpec.DefineStruct, apiSpec.NestedStruct: case apiSpec.DefineStruct, apiSpec.NestedStruct, apiSpec.MapType:
properties, requiredFields := propertiesFromType(itemType) properties, requiredFields := propertiesFromType(ctx, itemType)
return &spec.SchemaOrArray{ return &spec.SchemaOrArray{
Schema: &spec.Schema{ Schema: &spec.Schema{
SchemaProps: spec.SchemaProps{ SchemaProps: spec.SchemaProps{
Type: typeFromGoType(itemType), Type: typeFromGoType(ctx, itemType),
Items: itemsFromGoType(itemType), Items: itemsFromGoType(ctx, itemType),
Properties: properties, Properties: properties,
Required: requiredFields, Required: requiredFields,
AdditionalProperties: mapFromGoType(itemType), AdditionalProperties: mapFromGoType(ctx, itemType),
}, },
}, },
} }
case apiSpec.PointerType: case apiSpec.PointerType:
return itemFromGoType(itemType.Type) return itemFromGoType(ctx, itemType.Type)
case apiSpec.ArrayType: case apiSpec.ArrayType:
return &spec.SchemaOrArray{ return &spec.SchemaOrArray{
Schema: &spec.Schema{ Schema: &spec.Schema{
SchemaProps: spec.SchemaProps{ SchemaProps: spec.SchemaProps{
Type: typeFromGoType(itemType), Type: typeFromGoType(ctx, itemType),
Items: itemsFromGoType(itemType), Items: itemsFromGoType(ctx, itemType),
}, },
}, },
} }
@@ -140,7 +140,7 @@ func itemFromGoType(tp apiSpec.Type) *spec.SchemaOrArray {
return nil return nil
} }
func typeFromGoType(tp apiSpec.Type) []string { func typeFromGoType(ctx Context, tp apiSpec.Type) []string {
switch val := tp.(type) { switch val := tp.(type) {
case apiSpec.PrimitiveType: case apiSpec.PrimitiveType:
res, ok := tpMapper[val.RawName] res, ok := tpMapper[val.RawName]
@@ -152,12 +152,12 @@ func typeFromGoType(tp apiSpec.Type) []string {
case apiSpec.DefineStruct, apiSpec.MapType: case apiSpec.DefineStruct, apiSpec.MapType:
return []string{swaggerTypeObject} return []string{swaggerTypeObject}
case apiSpec.PointerType: case apiSpec.PointerType:
return typeFromGoType(val.Type) return typeFromGoType(ctx, val.Type)
} }
return nil return nil
} }
func sampleTypeFromGoType(tp apiSpec.Type) string { func sampleTypeFromGoType(ctx Context, tp apiSpec.Type) string {
switch val := tp.(type) { switch val := tp.(type) {
case apiSpec.PrimitiveType: case apiSpec.PrimitiveType:
return tpMapper[val.RawName] return tpMapper[val.RawName]
@@ -166,31 +166,30 @@ func sampleTypeFromGoType(tp apiSpec.Type) string {
case apiSpec.DefineStruct, apiSpec.MapType, apiSpec.NestedStruct: case apiSpec.DefineStruct, apiSpec.MapType, apiSpec.NestedStruct:
return swaggerTypeObject return swaggerTypeObject
case apiSpec.PointerType: case apiSpec.PointerType:
return sampleTypeFromGoType(val.Type) return sampleTypeFromGoType(ctx, val.Type)
default: default:
return "" return ""
} }
} }
func typeContainsTag(structType apiSpec.DefineStruct, tag string) bool { func typeContainsTag(ctx Context, structType apiSpec.DefineStruct, tag string) bool {
for _, field := range structType.Members { members := expandMembers(ctx, structType)
tags, _ := apiSpec.Parse(field.Tag) for _, member := range members {
for _, t := range tags.Tags() { tags, _ := apiSpec.Parse(member.Tag)
if t.Key == tag { if _, err := tags.Get(tag); err == nil {
return true return true
}
} }
} }
return false return false
} }
func expandMembers(tp apiSpec.Type) []apiSpec.Member { func expandMembers(ctx Context, tp apiSpec.Type) []apiSpec.Member {
var members []apiSpec.Member var members []apiSpec.Member
switch val := tp.(type) { switch val := tp.(type) {
case apiSpec.DefineStruct: case apiSpec.DefineStruct:
for _, v := range val.Members { for _, v := range val.Members {
if v.IsInline { if v.IsInline {
members = append(members, expandMembers(v.Type)...) members = append(members, expandMembers(ctx, v.Type)...)
continue continue
} }
members = append(members, v) members = append(members, v)
@@ -198,7 +197,7 @@ func expandMembers(tp apiSpec.Type) []apiSpec.Member {
case apiSpec.NestedStruct: case apiSpec.NestedStruct:
for _, v := range val.Members { for _, v := range val.Members {
if v.IsInline { if v.IsInline {
members = append(members, expandMembers(v.Type)...) members = append(members, expandMembers(ctx, v.Type)...)
continue continue
} }
members = append(members, v) members = append(members, v)
@@ -208,42 +207,42 @@ func expandMembers(tp apiSpec.Type) []apiSpec.Member {
return members return members
} }
func rangeMemberAndDo(structType apiSpec.Type, do func(tag *apiSpec.Tags, required bool, member apiSpec.Member)) { func rangeMemberAndDo(ctx Context, structType apiSpec.Type, do func(tag *apiSpec.Tags, required bool, member apiSpec.Member)) {
var members = expandMembers(structType) var members = expandMembers(ctx, structType)
for _, field := range members { for _, field := range members {
tags, _ := apiSpec.Parse(field.Tag) tags, _ := apiSpec.Parse(field.Tag)
required := isRequired(tags) required := isRequired(ctx, tags)
do(tags, required, field) do(tags, required, field)
} }
} }
func isRequired(tags *apiSpec.Tags) bool { func isRequired(ctx Context, tags *apiSpec.Tags) bool {
tag, err := tags.Get(tagJson) tag, err := tags.Get(tagJson)
if err == nil { if err == nil {
return !isOptional(tag.Options) return !isOptional(ctx, tag.Options)
} }
tag, err = tags.Get(tagForm) tag, err = tags.Get(tagForm)
if err == nil { if err == nil {
return !isOptional(tag.Options) return !isOptional(ctx, tag.Options)
} }
tag, err = tags.Get(tagPath) tag, err = tags.Get(tagPath)
if err == nil { if err == nil {
return !isOptional(tag.Options) return !isOptional(ctx, tag.Options)
} }
return false return false
} }
func isOptional(options []string) bool { func isOptional(_ Context, options []string) bool {
for _, option := range options { for _, option := range options {
if option == "optional" { if option == optionalFlag {
return true return true
} }
} }
return false return false
} }
func pathVariable2SwaggerVariable(path string) string { func pathVariable2SwaggerVariable(_ Context, path string) string {
pathItems := strings.FieldsFunc(path, slashRune) pathItems := strings.FieldsFunc(path, slashRune)
var resp []string var resp []string
for _, v := range pathItems { for _, v := range pathItems {
@@ -256,13 +255,12 @@ func pathVariable2SwaggerVariable(path string) string {
return "/" + strings.Join(resp, "/") return "/" + strings.Join(resp, "/")
} }
func wrapCodeMsgProps(properties spec.SchemaProps, api apiSpec.Info, atDoc apiSpec.AtDoc) spec.SchemaProps { func wrapCodeMsgProps(ctx Context, properties spec.SchemaProps, atDoc apiSpec.AtDoc) spec.SchemaProps {
wrapCodeMsg := getBoolFromKVOrDefault(api.Properties, "wrapCodeMsg", false) if !ctx.WrapCodeMsg {
if !wrapCodeMsg {
return properties return properties
} }
globalCodeDesc := getStringFromKVOrDefault(api.Properties, "bizCodeEnumDescription", "business code") globalCodeDesc := ctx.BizCodeEnumDescription
methodCodeDesc := getStringFromKVOrDefault(atDoc.Properties, "bizCodeEnumDescription", globalCodeDesc) methodCodeDesc := getStringFromKVOrDefault(atDoc.Properties, propertyKeyBizCodeEnumDescription, globalCodeDesc)
return spec.SchemaProps{ return spec.SchemaProps{
Type: []string{swaggerTypeObject}, Type: []string{swaggerTypeObject},
Properties: spec.SchemaProperties{ Properties: spec.SchemaProperties{
@@ -295,27 +293,27 @@ func specExtensions(api apiSpec.Info) (spec.Extensions, *spec.Info) {
ext := spec.Extensions{} ext := spec.Extensions{}
ext.Add("x-goctl-version", version.BuildVersion) ext.Add("x-goctl-version", version.BuildVersion)
ext.Add("x-description", "This is a goctl generated swagger file.") ext.Add("x-description", "This is a goctl generated swagger file.")
ext.Add("x-date", time.Now().Format("2006-01-02 15:04:05")) ext.Add("x-date", time.Now().Format(time.DateTime))
ext.Add("x-github", "https://github.com/zeromicro/go-zero") ext.Add("x-github", "https://github.com/zeromicro/go-zero")
ext.Add("x-go-zero-doc", "https://go-zero.dev/") ext.Add("x-go-zero-doc", "https://go-zero.dev/")
info := &spec.Info{} info := &spec.Info{}
info.Description = util.Unquote(api.Properties["description"]) info.Title = getStringFromKVOrDefault(api.Properties, propertyKeyTitle, "")
info.Title = util.Unquote(api.Properties["title"]) info.Description = getStringFromKVOrDefault(api.Properties, propertyKeyDescription, "")
info.TermsOfService = util.Unquote(api.Properties["termsOfService"]) info.TermsOfService = getStringFromKVOrDefault(api.Properties, propertyKeyTermsOfService, "")
info.Version = util.Unquote(api.Properties["version"]) info.Version = getStringFromKVOrDefault(api.Properties, propertyKeyVersion, "1.0")
contactInfo := spec.ContactInfo{} contactInfo := spec.ContactInfo{}
contactInfo.Name = util.Unquote(api.Properties["contactName"]) contactInfo.Name = getStringFromKVOrDefault(api.Properties, propertyKeyContactName, "")
contactInfo.URL = util.Unquote(api.Properties["contactURL"]) contactInfo.URL = getStringFromKVOrDefault(api.Properties, propertyKeyContactURL, "")
contactInfo.Email = util.Unquote(api.Properties["contactEmail"]) contactInfo.Email = getStringFromKVOrDefault(api.Properties, propertyKeyContactEmail, "")
if len(contactInfo.Name) > 0 || len(contactInfo.URL) > 0 || len(contactInfo.Email) > 0 { if len(contactInfo.Name) > 0 || len(contactInfo.URL) > 0 || len(contactInfo.Email) > 0 {
info.Contact = &contactInfo info.Contact = &contactInfo
} }
license := &spec.License{} license := &spec.License{}
license.Name = util.Unquote(api.Properties["licenseName"]) license.Name = getStringFromKVOrDefault(api.Properties, propertyKeyLicenseName, "")
license.URL = util.Unquote(api.Properties["licenseURL"]) license.URL = getStringFromKVOrDefault(api.Properties, propertyKeyLicenseURL, "")
if len(license.Name) > 0 || len(license.URL) > 0 { if len(license.Name) > 0 || len(license.URL) > 0 {
info.License = license info.License = license
} }

View File

@@ -19,7 +19,7 @@ func Test_pathVariable2SwaggerVariable(t *testing.T) {
} }
for _, tc := range testCases { for _, tc := range testCases {
result := pathVariable2SwaggerVariable(tc.input) result := pathVariable2SwaggerVariable(testingContext(t), tc.input)
assert.Equal(t, tc.expected, result) assert.Equal(t, tc.expected, result)
} }
} }

2
tools/goctl/build.env Normal file
View File

@@ -0,0 +1,2 @@
APP_NAME=goctl
APP_VERSION=1.8.4-beta

50
tools/goctl/build.sh Normal file
View File

@@ -0,0 +1,50 @@
#!/bin/bash
source build.env
APP_NAME=$APP_NAME
VERSION=$APP_VERSION
BUILD_DIR="dist"
ZIP_DIR="${BUILD_DIR}/zips"
PLATFORMS=(
"linux/amd64"
"linux/arm64"
"darwin/amd64"
"darwin/arm64"
"windows/amd64"
"windows/arm64"
)
rm -rf "${BUILD_DIR}"
mkdir -p "${ZIP_DIR}"
for PLATFORM in "${PLATFORMS[@]}"; do
GOOS=${PLATFORM%/*}
GOARCH=${PLATFORM#*/}
OUTPUT="${BUILD_DIR}/${APP_NAME}-${VERSION}-${GOOS}-${GOARCH}"
if [ "${GOOS}" = "windows" ]; then
OUTPUT="${OUTPUT}.exe"
fi
echo "Building for ${GOOS}/${GOARCH}..."
env GOOS="${GOOS}" GOARCH="${GOARCH}" go build -o "${OUTPUT}" goctl.go
if [ $? -ne 0 ]; then
echo "Error building for ${GOOS}/${GOARCH}"
exit 1
fi
ZIP_OUTPUT="${ZIP_DIR}/$(basename "${OUTPUT}")"
if [ "${GOOS}" = "windows" ]; then
zip -j "${ZIP_OUTPUT%.exe}.zip" "${OUTPUT}"
else
zip -j "${ZIP_OUTPUT}.zip" "${OUTPUT}"
fi
echo "Created zip: ${ZIP_OUTPUT}.zip"
done
echo "All builds completed successfully. Zip files are in ${ZIP_DIR}/"

103
tools/goctl/change.md Normal file
View File

@@ -0,0 +1,103 @@
# 1.8.4-beta
## swagger
- [features] Supported operation id for swagger
## Other
- Updated version to 1.8.4-beta
# 1.8.4-alpha
## swagger
1. [bug fix] remove example generation when request body are `query`, `path` and `header`
- it not supported in api spec 2.0
- it's will generate example when request body is json format.
2. [features] swagger generation supported definitions
- supported response definitions
- supported json request body definitions
- do not support query and form definitions, use parameters instead.
**How to use?**
Use the `useDefinitions` keyword in the info code block of the API file to declare the enable. This value is a boolean value. When set to `true`, it will enable the generation of definitions. Otherwise, it will be generated according to properties, and the default is `false`, for example:
```go
syntax = "v1"
info (
...
wrapCodeMsg: true
useDefinitions: true
)
...
```
the demo result of swagger.json
```json
{
...
"responses": {
"200": {
"description": "",
"schema": {
"type": "object",
"properties": {
"code": {
"description": "1001-User not login\u003cbr\u003e1002-User permission denied",
"type": "integer",
"example": 0
},
"data": {
"$ref": "#/definitions/FormResp"
},
"msg": {
"description": "business message",
"type": "string",
"example": "ok"
}
}
}
}
}
...
}
```
For a complete API example, please refer to the `api/swagger/example/example.api` file in pr. For a complete swagger result example, please refer to the `api/swagger/example/example.swagger.json` file in pr.
## 2. `goctl api go` code generation
- [bug-fix] Add flag `--type-group` to control the output of types(deprecated: experimental switch control type grouping is no longer used), if true, the types in only one group will separate by file.
- example `goctl api go --api demo.api --dir demo --type-group`
- use `group` keyword in @server block to define the group name in api file, for example
```go
@server(
group: user
)
service demo{
...
}
```
the example of separated types by file
```
.
└── types
├── common.go
├── gotoolexport.go
├── importfile.go
├── process.go
└── types.go
```
## 3 API Parser
- supported identifier value for info key-value in api parser
for example
```
syntax = "v1"
info(
enable: true
disable: false
)
...
```

View File

@@ -4,7 +4,7 @@ go 1.21
require ( require (
github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/emicklei/proto v1.14.1 github.com/emicklei/proto v1.14.2
github.com/fatih/structtag v1.2.0 github.com/fatih/structtag v1.2.0
github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e
github.com/go-sql-driver/mysql v1.9.0 github.com/go-sql-driver/mysql v1.9.0
@@ -16,7 +16,7 @@ require (
github.com/withfig/autocomplete-tools/integrations/cobra v1.2.1 github.com/withfig/autocomplete-tools/integrations/cobra v1.2.1
github.com/zeromicro/antlr v0.0.1 github.com/zeromicro/antlr v0.0.1
github.com/zeromicro/ddl-parser v1.0.5 github.com/zeromicro/ddl-parser v1.0.5
github.com/zeromicro/go-zero v1.8.2 github.com/zeromicro/go-zero v1.8.4
golang.org/x/text v0.22.0 golang.org/x/text v0.22.0
google.golang.org/grpc v1.65.0 google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.36.5 google.golang.org/protobuf v1.36.5
@@ -25,8 +25,7 @@ require (
require ( require (
filippo.io/edwards25519 v1.1.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect github.com/alicebob/miniredis/v2 v2.35.0 // indirect
github.com/alicebob/miniredis/v2 v2.34.0 // indirect
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec // indirect github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect
@@ -49,6 +48,8 @@ require (
github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-cmp v0.6.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect github.com/google/gofuzz v1.2.0 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/grafana/pyroscope-go v1.2.2 // indirect
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -72,7 +73,7 @@ require (
github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect
github.com/redis/go-redis/v9 v9.7.3 // indirect github.com/redis/go-redis/v9 v9.10.0 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect

View File

@@ -2,10 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8=
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec h1:EEyRvzmpEUZ+I8WmD5cw/vY8EqhambkOqy5iFr0908A= github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec h1:EEyRvzmpEUZ+I8WmD5cw/vY8EqhambkOqy5iFr0908A=
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY= github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
@@ -32,8 +30,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g=
github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/emicklei/proto v1.14.1 h1:fFq+Bj70XXZWXWikcVRvYZxrMS4KIIiPAqdJ8vPrenY= github.com/emicklei/proto v1.14.2 h1:wJPxPy2Xifja9cEMrcA/g08art5+7CGJNFNk35iXC1I=
github.com/emicklei/proto v1.14.1/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A= github.com/emicklei/proto v1.14.2/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4=
@@ -77,6 +75,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w= github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
@@ -146,8 +148,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= github.com/redis/go-redis/v9 v9.10.0 h1:FxwK3eV8p/CQa0Ch276C7u2d0eNC9kCmAYQ7mCXCzVs=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/redis/go-redis/v9 v9.10.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
@@ -183,8 +185,8 @@ github.com/zeromicro/antlr v0.0.1 h1:CQpIn/dc0pUjgGQ81y98s/NGOm2Hfru2NNio2I9mQgk
github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M= github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
github.com/zeromicro/ddl-parser v1.0.5 h1:LaVqHdzMTjasua1yYpIYaksxKqRzFrEukj2Wi2EbWaQ= github.com/zeromicro/ddl-parser v1.0.5 h1:LaVqHdzMTjasua1yYpIYaksxKqRzFrEukj2Wi2EbWaQ=
github.com/zeromicro/ddl-parser v1.0.5/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8= github.com/zeromicro/ddl-parser v1.0.5/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8=
github.com/zeromicro/go-zero v1.8.2 h1:AbJckBoojbr1lqCN1dkvURTIHOau7yvKReEd7ZmjuCk= github.com/zeromicro/go-zero v1.8.4 h1:3s7kOoThCnkDoqCafsqSX58Y9osYTBIa5QEmomw07TE=
github.com/zeromicro/go-zero v1.8.2/go.mod h1:G5dF+jzCEuq0t1j8qdrtVAy30QMgctGcKSfqFIGsvSg= github.com/zeromicro/go-zero v1.8.4/go.mod h1:eM5f6If/RF+jG1wSCmlvfXD2h2l23vJwETI8oDpjYt4=
go.etcd.io/etcd/api/v3 v3.5.15 h1:3KpLJir1ZEBrYuV2v+Twaa/e2MdDCEZ/70H+lzEiwsk= go.etcd.io/etcd/api/v3 v3.5.15 h1:3KpLJir1ZEBrYuV2v+Twaa/e2MdDCEZ/70H+lzEiwsk=
go.etcd.io/etcd/api/v3 v3.5.15/go.mod h1:N9EhGzXq58WuMllgH9ZvnEr7SI9pS0k0+DHZezGp7jM= go.etcd.io/etcd/api/v3 v3.5.15/go.mod h1:N9EhGzXq58WuMllgH9ZvnEr7SI9pS0k0+DHZezGp7jM=
go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5fWlA= go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5fWlA=

View File

@@ -38,7 +38,8 @@
"remote": "{{.global.remote}}", "remote": "{{.global.remote}}",
"branch": "{{.global.branch}}", "branch": "{{.global.branch}}",
"style": "{{.global.style}}", "style": "{{.global.style}}",
"test": "Generate test files" "test": "Generate test files",
"type-group": "Generate type group files"
}, },
"new": { "new": {
"short": "Fast create api service", "short": "Fast create api service",
@@ -205,6 +206,7 @@
"short": "Generate mongo model", "short": "Generate mongo model",
"type": "Specified model type name", "type": "Specified model type name",
"cache": "Generate code with cache [optional]", "cache": "Generate code with cache [optional]",
"prefix": "Generate code with cache prefix [optional]",
"easy": "Generate code with auto generated CollectionName for easy declare [optional]", "easy": "Generate code with auto generated CollectionName for easy declare [optional]",
"dir": "{{.goctl.model.dir}}", "dir": "{{.goctl.model.dir}}",
"style": "{{.global.style}}", "style": "{{.global.style}}",

View File

@@ -6,7 +6,7 @@ import (
) )
// BuildVersion is the version of goctl. // BuildVersion is the version of goctl.
const BuildVersion = "1.8.3" const BuildVersion = "1.8.4"
var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-bata": 2, "beta": 3, "released": 4, "": 5} var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-bata": 2, "beta": 3, "released": 4, "": 5}

View File

@@ -4,9 +4,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"slices"
"time" "time"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/tools/goctl/rpc/execx" "github.com/zeromicro/go-zero/tools/goctl/rpc/execx"
"github.com/zeromicro/go-zero/tools/goctl/util/console" "github.com/zeromicro/go-zero/tools/goctl/util/console"
"github.com/zeromicro/go-zero/tools/goctl/util/ctx" "github.com/zeromicro/go-zero/tools/goctl/util/ctx"
@@ -37,7 +37,7 @@ func editMod(version string, verbose bool) error {
return err return err
} }
if !stringx.Contains(latest, version) { if !slices.Contains(latest, version) {
return fmt.Errorf("release version %q is not found", version) return fmt.Errorf("release version %q is not found", version)
} }

View File

@@ -62,6 +62,7 @@ func init() {
mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t") mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t")
mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c") mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c")
mongoCmdFlags.StringVarP(&mongo.VarStringPrefix, "prefix", "p")
mongoCmdFlags.BoolVarP(&mongo.VarBoolEasy, "easy", "e") mongoCmdFlags.BoolVarP(&mongo.VarBoolEasy, "easy", "e")
mongoCmdFlags.StringVarP(&mongo.VarStringDir, "dir", "d") mongoCmdFlags.StringVarP(&mongo.VarStringDir, "dir", "d")
mongoCmdFlags.StringVar(&mongo.VarStringStyle, "style") mongoCmdFlags.StringVar(&mongo.VarStringStyle, "style")

View File

@@ -17,6 +17,7 @@ import (
type Context struct { type Context struct {
Types []string Types []string
Cache bool Cache bool
Prefix string
Easy bool Easy bool
Output string Output string
Cfg *config.Config Cfg *config.Config
@@ -60,6 +61,7 @@ func generateModel(ctx *Context) error {
"Type": stringx.From(t).Title(), "Type": stringx.From(t).Title(),
"lowerType": stringx.From(t).Untitle(), "lowerType": stringx.From(t).Untitle(),
"Cache": ctx.Cache, "Cache": ctx.Cache,
"Prefix": ctx.Prefix,
"version": version.BuildVersion, "version": version.BuildVersion,
}, output, true); err != nil { }, output, true); err != nil {
return err return err

View File

@@ -19,6 +19,8 @@ var (
VarStringDir string VarStringDir string
// VarBoolCache describes whether cache is enabled. // VarBoolCache describes whether cache is enabled.
VarBoolCache bool VarBoolCache bool
// VarStringPrefix string describes the prefix for the cache key.
VarStringPrefix string
// VarBoolEasy describes whether to generate Collection Name in the code for easy declare. // VarBoolEasy describes whether to generate Collection Name in the code for easy declare.
VarBoolEasy bool VarBoolEasy bool
// VarStringStyle describes the style. // VarStringStyle describes the style.
@@ -35,6 +37,7 @@ var (
func Action(_ *cobra.Command, _ []string) error { func Action(_ *cobra.Command, _ []string) error {
tp := VarStringSliceType tp := VarStringSliceType
c := VarBoolCache c := VarBoolCache
p := VarStringPrefix
easy := VarBoolEasy easy := VarBoolEasy
o := strings.TrimSpace(VarStringDir) o := strings.TrimSpace(VarStringDir)
s := VarStringStyle s := VarStringStyle
@@ -74,6 +77,7 @@ func Action(_ *cobra.Command, _ []string) error {
return generate.Do(&generate.Context{ return generate.Do(&generate.Context{
Types: tp, Types: tp,
Cache: c, Cache: c,
Prefix: p,
Easy: easy, Easy: easy,
Output: a, Output: a,
Cfg: cfg, Cfg: cfg,

View File

@@ -13,7 +13,7 @@ import (
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
) )
{{if .Cache}}var prefix{{.Type}}CacheKey = "cache:{{.lowerType}}:"{{end}} {{if .Cache}}var prefix{{.Type}}CacheKey = "{{if .Prefix}}{{.Prefix}}:{{end}}cache:{{.lowerType}}:"{{end}}
type {{.lowerType}}Model interface{ type {{.lowerType}}Model interface{
Insert(ctx context.Context,data *{{.Type}}) error Insert(ctx context.Context,data *{{.Type}}) error

View File

@@ -1356,7 +1356,7 @@ func (p *Parser) parseKVExpression() *ast.KVExpr {
expr.Colon = p.curTokenNode() expr.Colon = p.curTokenNode()
// token STRING // token STRING
if !p.advanceIfPeekTokenIs(token.STRING, token.RAW_STRING) { if !p.advanceIfPeekTokenIs(token.STRING, token.RAW_STRING, token.IDENT) {
return nil return nil
} }

View File

@@ -130,6 +130,8 @@ func TestParser_Parse_infoStmt(t *testing.T) {
"author": `"type author here"`, "author": `"type author here"`,
"email": `"type email here"`, "email": `"type email here"`,
"version": `"type version here"`, "version": `"type version here"`,
"enable": `true`,
"disable": `false`,
} }
p := New("foo.api", infoTestAPI) p := New("foo.api", infoTestAPI)
result := p.Parse() result := p.Parse()

View File

@@ -4,4 +4,6 @@ info(
author: "type author here" author: "type author here"
email: "type email here" email: "type email here"
version: "type version here" version: "type version here"
enable: true
disable: false
) )

View File

@@ -10,6 +10,8 @@ info ( // info stmt
author: "type author here" author: "type author here"
email: "type email here" email: "type email here"
version: "type version here" version: "type version here"
enable: true
disable: false
) )
type AliasInt int type AliasInt int

View File

@@ -192,3 +192,131 @@ func TestUntitle(t *testing.T) {
assert.Equal(t, c.want, ret) assert.Equal(t, c.want, ret)
} }
} }
func TestContainsAny(t *testing.T) {
type args struct {
s string
runes []rune
}
tests := []struct {
name string
args args
want bool
}{
{
name: "runes is empty",
args: args{
s: "test",
runes: []rune{},
},
want: true,
},
{
name: "s is empty and runes is not empty",
args: args{
s: "",
runes: []rune{'a', 'b', 'c'},
},
want: false,
},
{
name: "s contains runes",
args: args{
s: "hello",
runes: []rune{'e', 'f'},
},
want: true,
},
{
name: "s does not contain runes",
args: args{
s: "hello",
runes: []rune{'x', 'y'},
},
want: false,
},
{
name: "s and runes both have one matching character",
args: args{
s: "a",
runes: []rune{'a'},
},
want: true,
},
{
name: "s and runes both have one non-matching character",
args: args{
s: "a",
runes: []rune{'b'},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equalf(t, tt.want, ContainsAny(tt.args.s, tt.args.runes...), "ContainsAny(%v, %v)", tt.args.s, tt.args.runes)
})
}
}
func TestContainsWhiteSpace(t *testing.T) {
type args struct {
s string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "contains space",
args: args{s: "hello world"},
want: true,
},
{
name: "contains newline",
args: args{s: "hello\nworld"},
want: true,
},
{
name: "contains tab",
args: args{s: "hello\tworld"},
want: true,
},
{
name: "contains form feed",
args: args{s: "hello\fworld"},
want: true,
},
{
name: "contains vertical tab",
args: args{s: "hello\vworld"},
want: true,
},
{
name: "no whitespace",
args: args{s: "helloworld"},
want: false,
},
{
name: "empty string",
args: args{s: ""},
want: false,
},
{
name: "only whitespace",
args: args{s: " \t\n\f\v"},
want: true,
},
{
name: "contains non-standard whitespace",
args: args{s: "hello\u00A0world"},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equalf(t, tt.want, ContainsWhiteSpace(tt.args.s), "ContainsWhiteSpace(%v)", tt.args.s)
})
}
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/mathx"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
) )
@@ -47,7 +46,7 @@ func TestDirectBuilder_Build(t *testing.T) {
}, cc, resolver.BuildOptions{}) }, cc, resolver.BuildOptions{})
assert.NoError(t, err) assert.NoError(t, err)
size := mathx.MinInt(test, subsetSize) size := min(test, subsetSize)
assert.Equal(t, size, len(cc.state.Addresses)) assert.Equal(t, size, len(cc.state.Addresses))
m := make(map[string]lang.PlaceholderType) m := make(map[string]lang.PlaceholderType)
for _, each := range cc.state.Addresses { for _, each := range cc.state.Addresses {