Compare commits

..

35 Commits

Author SHA1 Message Date
Kevin Wan
c9ff6a10d3 feat: support serverless in rest (#5001)
Signed-off-by: kevin <wanjunfeng@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-13 00:00:52 +08:00
Kevin Wan
a71e56de52 fix: context key error in sql read write mode (#5000) 2025-07-12 06:58:08 +08:00
Kevin Wan
bae8d4f4c8 chore: refactoring sql read write mode (#4990)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-11 01:05:55 +08:00
zhoushuguang
8c6266f338 sql read write support (#4976)
Co-authored-by: light.zhou <light.zhou@bkyo.io>
2025-07-09 16:04:56 +00:00
Kevin Wan
95d5b81f44 chore: optimize pr 4979 (#4988)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-09 23:55:24 +08:00
geekeryy
bca7bbc142 fix: correct duration type comparison in environment variable processing (#4979) 2025-07-09 15:22:27 +00:00
Kevin Wan
df9a52664b fix issue #4986 2025-07-08 13:58:48 +00:00
Kevin Wan
937cf0db96 Update readme-cn.md (#4983) 2025-07-04 11:02:49 +08:00
Kevin Wan
75cebb65f8 fix: timeout 0s not working (#4932)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-07-01 17:01:24 +08:00
dependabot[bot]
410f56e73a chore(deps): bump github.com/redis/go-redis/v9 from 9.10.0 to 9.11.0 (#4969)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-25 18:35:01 +08:00
dependabot[bot]
017909a3ab chore(deps): bump github.com/emicklei/proto from 1.14.1 to 1.14.2 in /tools/goctl (#4961)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-19 15:40:18 +08:00
kesonan
0d31e6c375 (goctl): fix #4943 (#4953) 2025-06-14 15:36:30 +00:00
Kevin Wan
0ba86b1849 chore: add more tests (#4949) 2025-06-13 22:10:08 +08:00
wanwu
4cacc4d9d3 fix: the time.Duration type panics due to numerical values (#4944)
Co-authored-by: sam.yang <sam.yang@yijinin.com>
2025-06-12 15:11:07 +00:00
Eric
a99c14da4a fix: typo of the logic of CpuThreshold in comments (#4942)
Co-authored-by: zhouyy <zhouyy@ickey.cn>
2025-06-12 08:28:44 +00:00
Kevin Wan
985582264a chore: fix warnings (#4940) 2025-06-12 00:04:29 +08:00
Kevin Wan
8364e341e1 chore: update go-zero dep (#4933)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-06-09 18:08:20 +08:00
Kevin Wan
0f2b589d4d Revert "fix: api group set timeout: 0s not working." (#4931) 2025-06-08 23:14:38 +08:00
spectatorMrZ
19fec36d24 fix: api group set timeout: 0s not working. (#4785) 2025-06-08 14:50:21 +00:00
Kevin Wan
f037bf344d chore: add more tests (#4930) 2025-06-08 22:08:04 +08:00
MarkJoyMa
d99cf35b07 Feat/continue profiling (#4867)
Co-authored-by: aiden.ma <Aiden.ma@yijinin.com>
Co-authored-by: aiden.ma <aiden.ma@bkyo.io>
2025-06-07 21:12:31 +08:00
Kevin Wan
f459f1b5ff chore: update goctl version (#4929) 2025-06-07 21:01:35 +08:00
Haiwei Zhang
0140fd417b feat(goctl): generate mongo model with cache prefix (#4907) 2025-06-07 12:54:33 +00:00
jaron
7969e0ca38 fix(goctl): Fix getting swagger consume types (#4903) 2025-06-07 12:46:34 +00:00
Kevin Wan
91c885b5b0 chore: add more unit tests for mcp (#4928) 2025-06-07 20:41:57 +08:00
MarkJoyMa
d4cccca387 Fix the problem that mcp request id is not of int type (#4914) 2025-06-07 10:37:18 +08:00
dependabot[bot]
4b2095ed03 chore(deps): bump github.com/redis/go-redis/v9 from 9.9.0 to 9.10.0 (#4926) 2025-06-07 10:07:26 +08:00
dependabot[bot]
1229eeb2d2 chore(deps): bump go.mongodb.org/mongo-driver from 1.17.3 to 1.17.4 (#4924) 2025-06-06 19:45:26 +08:00
dependabot[bot]
9142b146c5 chore(deps): bump github.com/alicebob/miniredis/v2 from 2.34.0 to 2.35.0 (#4919)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-06-06 10:09:15 +08:00
Kevin Wan
8a1b2d5aed chore: fix typo (#4920) 2025-06-05 22:51:22 +08:00
Leon cap
da5d39e6ca fix: correct spelling of 'cancellation' in timeout handler comment (#4916)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2025-06-05 22:42:53 +08:00
Leon cap
68c5a17c67 fix: correct spelling of 'underlying' in Header method comment (#4918) 2025-06-05 10:36:21 +00:00
Leon cap
b53f9f5f2d fix: correct spelling of 'TimeoutHandler' in timeout handler comment (#4917) 2025-06-04 15:48:37 +00:00
dependabot[bot]
36d57626b6 chore(deps): bump github.com/redis/go-redis/v9 from 9.8.0 to 9.9.0 (#4905)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-05-28 11:32:57 +08:00
Kevin Wan
4e36ba832f Update readme.md (#4897) 2025-05-25 22:25:56 +08:00
46 changed files with 1613 additions and 144 deletions

View File

@@ -423,7 +423,7 @@ func TestRegistry_Monitor(t *testing.T) {
GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{
watchKey{
{
key: "foo",
exactMatch: true,
}: {
@@ -449,7 +449,7 @@ func TestRegistry_Unmonitor(t *testing.T) {
GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{
watchKey{
{
key: "foo",
exactMatch: true,
}: {

View File

@@ -7,12 +7,11 @@ import (
)
var (
fieldsContextKey contextKey
globalFields atomic.Value
globalFieldsLock sync.Mutex
)
type contextKey struct{}
type fieldsKey struct{}
// AddGlobalFields adds global fields.
func AddGlobalFields(fields ...LogField) {
@@ -29,16 +28,16 @@ func AddGlobalFields(fields ...LogField) {
// ContextWithFields returns a new context with the given fields.
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
if val := ctx.Value(fieldsContextKey); val != nil {
if val := ctx.Value(fieldsKey{}); val != nil {
if arr, ok := val.([]LogField); ok {
allFields := make([]LogField, 0, len(arr)+len(fields))
allFields = append(allFields, arr...)
allFields = append(allFields, fields...)
return context.WithValue(ctx, fieldsContextKey, allFields)
return context.WithValue(ctx, fieldsKey{}, allFields)
}
}
return context.WithValue(ctx, fieldsContextKey, fields)
return context.WithValue(ctx, fieldsKey{}, fields)
}
// WithFields returns a new logger with the given fields.

View File

@@ -34,7 +34,7 @@ func TestAddGlobalFields(t *testing.T) {
func TestContextWithFields(t *testing.T) {
ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2))
vals := ctx.Value(fieldsContextKey)
vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals)
fields, ok := vals.([]LogField)
assert.True(t, ok)
@@ -43,7 +43,7 @@ func TestContextWithFields(t *testing.T) {
func TestWithFields(t *testing.T) {
ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2))
vals := ctx.Value(fieldsContextKey)
vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals)
fields, ok := vals.([]LogField)
assert.True(t, ok)
@@ -55,7 +55,7 @@ func TestWithFieldsAppend(t *testing.T) {
ctx := context.WithValue(context.Background(), dummyKey, "dummy")
ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2))
ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4))
vals := ctx.Value(fieldsContextKey)
vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals)
fields, ok := vals.([]LogField)
assert.True(t, ok)
@@ -80,8 +80,8 @@ func TestWithFieldsAppendCopy(t *testing.T) {
ctxa := ContextWithFields(ctx, af)
ctxb := ContextWithFields(ctx, bf)
assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count])
assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count])
assert.EqualValues(t, af, ctxa.Value(fieldsKey{}).([]LogField)[count])
assert.EqualValues(t, bf, ctxb.Value(fieldsKey{}).([]LogField)[count])
}
func BenchmarkAtomicValue(b *testing.B) {

View File

@@ -224,7 +224,7 @@ func (l *richLogger) buildFields(fields ...LogField) []LogField {
fields = append(fields, Field(spanKey, spanID))
}
val := l.ctx.Value(fieldsContextKey)
val := l.ctx.Value(fieldsKey{})
if val != nil {
if arr, ok := val.([]LogField); ok {
fields = append(fields, arr...)

View File

@@ -30,7 +30,9 @@ var (
errValueNotSettable = errors.New("value is not settable")
errValueNotStruct = errors.New("value type is not struct")
keyUnmarshaler = NewUnmarshaler(defaultKeyName)
boolType = reflect.TypeOf(false)
durationType = reflect.TypeOf(time.Duration(0))
stringType = reflect.TypeOf("")
cacheKeys = make(map[string][]string)
cacheKeysLock sync.Mutex
defaultCache = make(map[string]any)
@@ -622,9 +624,19 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
return u.fillSliceFromString(fieldType, value, mapValue, fullName)
case valueKind == reflect.String && derefedFieldType == durationType:
return fillDurationValue(fieldType, value, mapValue.(string))
v, err := convertToString(mapValue, fullName)
if err != nil {
return err
}
return fillDurationValue(fieldType, value, v)
case valueKind == reflect.String && typeKind == reflect.Struct && u.implementsUnmarshaler(fieldType):
return u.fillUnmarshalerStruct(fieldType, value, mapValue.(string))
v, err := convertToString(mapValue, fullName)
if err != nil {
return err
}
return u.fillUnmarshalerStruct(fieldType, value, v)
default:
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
}
@@ -755,24 +767,24 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
return err
}
fieldKind := fieldType.Kind()
switch fieldKind {
case reflect.Bool:
derefType := Deref(fieldType)
switch derefType {
case boolType:
val, err := strconv.ParseBool(envVal)
if err != nil {
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
}
value.SetBool(val)
SetValue(fieldType, value, reflect.ValueOf(val))
return nil
case durationType.Kind():
case durationType:
if err := fillDurationValue(fieldType, value, envVal); err != nil {
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
}
return nil
case reflect.String:
value.SetString(envVal)
case stringType:
SetValue(fieldType, value, reflect.ValueOf(envVal))
return nil
default:
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, json.Number(envVal), opts, fullName)

View File

@@ -203,6 +203,20 @@ func TestUnmarshalDuration(t *testing.T) {
}
}
func TestUnmarshalDurationUnexpectedError(t *testing.T) {
type inner struct {
Duration time.Duration `key:"duration"`
}
content := "{\"duration\": 1}"
var m = map[string]any{}
err := jsonx.Unmarshal([]byte(content), &m)
assert.NoError(t, err)
var in inner
err = UnmarshalKey(m, &in)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expect string")
}
func TestUnmarshalDurationDefault(t *testing.T) {
type inner struct {
Int int `key:"int"`
@@ -4665,6 +4679,23 @@ func TestUnmarshal_EnvInt(t *testing.T) {
}
}
func TestUnmarshal_EnvInt64(t *testing.T) {
type Value struct {
Age int64 `key:"age,env=TEST_NAME_INT64"`
}
const (
envName = "TEST_NAME_INT64"
envVal = "88"
)
t.Setenv(envName, envVal)
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, int64(88), v.Age)
}
}
func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
type Value struct {
Age int `key:"age,env=TEST_NAME_INT"`
@@ -4770,20 +4801,33 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) {
}
func TestUnmarshal_EnvDuration(t *testing.T) {
type Value struct {
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
}
const (
envName = "TEST_NAME_DURATION"
envVal = "1s"
)
t.Setenv(envName, envVal)
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, time.Second, v.Duration)
}
t.Run("valid duration", func(t *testing.T) {
type Value struct {
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
}
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, time.Second, v.Duration)
}
})
t.Run("ptr of duration", func(t *testing.T) {
type Value struct {
Duration *time.Duration `key:"duration,env=TEST_NAME_DURATION"`
}
var v Value
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
assert.Equal(t, time.Second, *v.Duration)
}
})
}
func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
@@ -5995,6 +6039,16 @@ func TestUnmarshal_Unmarshaler(t *testing.T) {
}, &v))
assert.Nil(t, v.Foo)
})
t.Run("json.Number", func(t *testing.T) {
v := struct {
Foo *mockUnmarshaler `json:"name"`
}{}
m := map[string]any{
"name": json.Number("123"),
}
assert.Error(t, UnmarshalJsonMap(m, &v))
})
}
func TestParseJsonStringValue(t *testing.T) {

View File

@@ -92,6 +92,15 @@ func ValidatePtr(v reflect.Value) error {
return nil
}
func convertToString(val any, fullName string) (string, error) {
v, ok := val.(string)
if !ok {
return "", fmt.Errorf("expect string for field %s, but got type %T", fullName, val)
}
return v, nil
}
func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
switch kind {
case reflect.Bool:

View File

@@ -8,6 +8,7 @@ import (
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/trace"
"github.com/zeromicro/go-zero/internal/devserver"
"github.com/zeromicro/go-zero/internal/profiling"
)
const (
@@ -38,6 +39,8 @@ type (
Telemetry trace.Config `json:",optional"`
DevServer DevServerConfig `json:",optional"`
Shutdown proc.ShutdownConf `json:",optional"`
// Profiling is the configuration for continuous profiling.
Profiling profiling.Config `json:",optional"`
}
)
@@ -70,7 +73,9 @@ func (sc ServiceConf) SetUp() error {
if len(sc.MetricsUrl) > 0 {
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
}
devserver.StartAgent(sc.DevServer)
profiling.Start(sc.Profiling)
return nil
}

View File

@@ -0,0 +1,29 @@
package sqlx
import "errors"
var (
errEmptyDatasource = errors.New("empty datasource")
errEmptyDriverName = errors.New("empty driver name")
)
// SqlConf defines the configuration for sqlx.
type SqlConf struct {
DataSource string
DriverName string `json:",default=mysql"`
Replicas []string `json:",optional"`
Policy string `json:",default=round-robin,options=round-robin|random"`
}
// Validate validates the SqlxConf.
func (sc SqlConf) Validate() error {
if len(sc.DataSource) == 0 {
return errEmptyDatasource
}
if len(sc.DriverName) == 0 {
return errEmptyDriverName
}
return nil
}

View File

@@ -0,0 +1,29 @@
package sqlx
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf"
)
func TestValidate(t *testing.T) {
text := []byte(`DataSource: primary:password@tcp(127.0.0.1:3306)/primary_db
`)
var sc SqlConf
err := conf.LoadFromYamlBytes(text, &sc)
assert.Nil(t, err)
assert.Equal(t, "mysql", sc.DriverName)
assert.Equal(t, policyRoundRobin, sc.Policy)
assert.Nil(t, sc.Validate())
sc = SqlConf{}
assert.Equal(t, errEmptyDatasource, sc.Validate())
sc.DataSource = "primary:password@tcp(127.0.0.1:3306)/primary_db"
assert.Equal(t, errEmptyDriverName, sc.Validate())
sc.DriverName = "mysql"
assert.Nil(t, sc.Validate())
}

View File

@@ -0,0 +1,65 @@
package sqlx
import "context"
const (
// policyRoundRobin round-robin policy for selecting replicas.
policyRoundRobin = "round-robin"
// policyRandom random policy for selecting replicas.
policyRandom = "random"
// readPrimaryMode indicates that the operation is a read,
// but should be performed on the primary database instance.
//
// This mode is used in scenarios where data freshness and consistency are critical,
// such as immediately after writes or where replication lag may cause stale reads.
readPrimaryMode readWriteMode = "read-primary"
// readReplicaMode indicates that the operation is a read from replicas.
// This is suitable for scenarios where eventual consistency is acceptable,
// and the goal is to offload traffic from the primary and improve read scalability.
readReplicaMode readWriteMode = "read-replica"
// writeMode indicates that the operation is a write operation (to primary).
writeMode readWriteMode = "write"
// notSpecifiedMode indicates that the read/write mode is not specified.
notSpecifiedMode readWriteMode = ""
)
type readWriteModeKey struct{}
// WithReadPrimary sets the context to read-primary mode.
func WithReadPrimary(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey{}, readPrimaryMode)
}
// WithReadReplica sets the context to read-replica mode.
func WithReadReplica(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey{}, readReplicaMode)
}
// WithWrite sets the context to write mode, indicating that the operation is a write operation.
func WithWrite(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey{}, writeMode)
}
type readWriteMode string
func (m readWriteMode) isValid() bool {
return m == readPrimaryMode || m == readReplicaMode || m == writeMode
}
func getReadWriteMode(ctx context.Context) readWriteMode {
if mode := ctx.Value(readWriteModeKey{}); mode != nil {
if v, ok := mode.(readWriteMode); ok && v.isValid() {
return v
}
}
return notSpecifiedMode
}
func usePrimary(ctx context.Context) bool {
return getReadWriteMode(ctx) != readReplicaMode
}

View File

@@ -0,0 +1,142 @@
package sqlx
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsValid(t *testing.T) {
testCases := []struct {
name string
mode readWriteMode
expected bool
}{
{
name: "valid read-primary mode",
mode: readPrimaryMode,
expected: true,
},
{
name: "valid read-replica mode",
mode: readReplicaMode,
expected: true,
},
{
name: "valid write mode",
mode: writeMode,
expected: true,
},
{
name: "not specified mode (empty)",
mode: notSpecifiedMode,
expected: false,
},
{
name: "invalid custom string",
mode: readWriteMode("delete"),
expected: false,
},
{
name: "case sensitive check",
mode: readWriteMode("READ"),
expected: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := tc.mode.isValid()
assert.Equal(t, tc.expected, actual)
})
}
}
func TestWithReadMode(t *testing.T) {
ctx := context.Background()
readPrimaryCtx := WithReadPrimary(ctx)
val := readPrimaryCtx.Value(readWriteModeKey{})
assert.Equal(t, readPrimaryMode, val)
readReplicaCtx := WithReadReplica(ctx)
val = readReplicaCtx.Value(readWriteModeKey{})
assert.Equal(t, readReplicaMode, val)
}
func TestWithWriteMode(t *testing.T) {
ctx := context.Background()
writeCtx := WithWrite(ctx)
val := writeCtx.Value(readWriteModeKey{})
assert.Equal(t, writeMode, val)
}
func TestGetReadWriteMode(t *testing.T) {
t.Run("valid read-primary mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx))
})
t.Run("valid read-replica mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
assert.Equal(t, readReplicaMode, getReadWriteMode(ctx))
})
t.Run("valid write mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
assert.Equal(t, writeMode, getReadWriteMode(ctx))
})
t.Run("invalid mode value (wrong type)", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, "not-a-mode")
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
})
t.Run("invalid mode value (wrong value)", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("delete"))
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
})
t.Run("no mode set", func(t *testing.T) {
ctx := context.Background()
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
})
}
func TestUsePrimary(t *testing.T) {
t.Run("context with read-replica mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
assert.False(t, usePrimary(ctx))
})
t.Run("context with read-primary mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
assert.True(t, usePrimary(ctx))
})
t.Run("context with write mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
assert.True(t, usePrimary(ctx))
})
t.Run("context with invalid mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("invalid"))
assert.True(t, usePrimary(ctx))
})
t.Run("context with no mode set", func(t *testing.T) {
ctx := context.Background()
assert.True(t, usePrimary(ctx))
})
}
func TestWithModeTwice(t *testing.T) {
ctx := context.Background()
ctx = WithReadPrimary(ctx)
writeCtx := WithWrite(ctx)
val := writeCtx.Value(readWriteModeKey{})
assert.Equal(t, writeMode, val)
}

View File

@@ -4,6 +4,9 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"math/rand"
"sync/atomic"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/errorx"
@@ -52,9 +55,10 @@ type (
beginTx beginnable
brk breaker.Breaker
accept breaker.Acceptable
index uint32
}
connProvider func() (*sql.DB, error)
connProvider func(ctx context.Context) (*sql.DB, error)
sessionConn interface {
Exec(query string, args ...any) (sql.Result, error)
@@ -64,10 +68,41 @@ type (
}
)
// MustNewConn returns a SqlConn with the given SqlConf.
func MustNewConn(c SqlConf, opts ...SqlOption) SqlConn {
conn, err := NewConn(c, opts...)
if err != nil {
logx.Must(err)
}
return conn
}
// NewConn returns a SqlConn with the given SqlConf.
func NewConn(c SqlConf, opts ...SqlOption) (SqlConn, error) {
if err := c.Validate(); err != nil {
return nil, err
}
conn := &commonSqlConn{
onError: func(ctx context.Context, err error) {
logInstanceError(ctx, c.DataSource, err)
},
beginTx: begin,
brk: breaker.NewBreaker(),
}
for _, opt := range opts {
opt(conn)
}
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
return conn, nil
}
// NewSqlConn returns a SqlConn with given driver name and datasource.
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
connProv: func(context.Context) (*sql.DB, error) {
return getSqlConn(driverName, datasource)
},
onError: func(ctx context.Context, err error) {
@@ -87,7 +122,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
// Use it with caution; it's provided for other ORM to interact with.
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
connProv: func(ctx context.Context) (*sql.DB, error) {
return db, nil
},
onError: func(ctx context.Context, err error) {
@@ -123,7 +158,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
var conn *sql.DB
conn, err = db.connProv()
conn, err = db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err
@@ -151,7 +186,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
var conn *sql.DB
conn, err = db.connProv()
conn, err = db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err
@@ -242,7 +277,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
}
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
return db.connProv()
return db.connProv(context.Background())
}
func (db *commonSqlConn) Transact(fn func(Session) error) error {
@@ -288,7 +323,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
q string, args ...any) (err error) {
var scanFailed bool
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
conn, err := db.connProv()
conn, err := db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err
@@ -311,6 +346,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
return
}
func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, replicas []string) connProvider {
return func(ctx context.Context) (*sql.DB, error) {
replicaCount := len(replicas)
if replicaCount == 0 || usePrimary(ctx) {
return getSqlConn(driverName, datasource)
}
var dsn string
if replicaCount == 1 {
dsn = replicas[0]
} else {
if len(policy) == 0 {
policy = policyRoundRobin
}
switch policy {
case policyRandom:
dsn = replicas[rand.Intn(replicaCount)]
case policyRoundRobin:
index := atomic.AddUint32(&sc.index, 1) - 1
dsn = replicas[index%uint32(replicaCount)]
default:
return nil, fmt.Errorf("unknown policy: %s", policy)
}
}
return getSqlConn(driverName, dsn)
}
}
// WithAcceptable returns a SqlOption that setting the acceptable function.
// acceptable is the func to check if the error can be accepted.
func WithAcceptable(acceptable func(err error) bool) SqlOption {

View File

@@ -1,6 +1,7 @@
package sqlx
import (
"context"
"database/sql"
"errors"
"io"
@@ -98,7 +99,7 @@ func TestSqlConn_RawDB(t *testing.T) {
func TestSqlConn_Errors(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db)
conn.(*commonSqlConn).connProv = func() (*sql.DB, error) {
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
return nil, errors.New("error")
}
_, err := conn.Prepare("any")
@@ -138,6 +139,148 @@ func TestSqlConn_Errors(t *testing.T) {
})
}
func TestConfigSqlConn(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
mock.ExpectExec("any")
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf, withMysqlAcceptable())
_, err = conn.Exec("any", "value")
assert.NotNil(t, err)
_, err = conn.Prepare("any")
assert.NotNil(t, err)
var val string
assert.NotNil(t, conn.QueryRow(&val, "any"))
assert.NotNil(t, conn.QueryRowPartial(&val, "any"))
assert.NotNil(t, conn.QueryRows(&val, "any"))
assert.NotNil(t, conn.QueryRowsPartial(&val, "any"))
}
func TestConfigSqlConnStatement(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
mock.ExpectPrepare("any")
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectPrepare("any")
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(row)
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf, withMysqlAcceptable())
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
res, err := stmt.Exec()
assert.NoError(t, err)
lastInsertID, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), lastInsertID)
rowsAffected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), rowsAffected)
stmt, err = conn.Prepare("any")
assert.NoError(t, err)
var val string
err = stmt.QueryRow(&val)
assert.NoError(t, err)
assert.Equal(t, "bar", val)
mock.ExpectPrepare("any")
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
stmt, err = conn.Prepare("any")
assert.NoError(t, err)
var vals []string
assert.NoError(t, stmt.QueryRowsPartial(&vals))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
}
func TestConfigSqlConnQuery(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
t.Run("QueryRow", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var val string
assert.NoError(t, conn.QueryRow(&val, "any"))
assert.Equal(t, "bar", val)
})
t.Run("QueryRowPartial", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var val string
assert.NoError(t, conn.QueryRowPartial(&val, "any"))
assert.Equal(t, "bar", val)
})
t.Run("QueryRows", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var vals []string
assert.NoError(t, conn.QueryRows(&vals, "any"))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
})
t.Run("QueryRowsPartial", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
var vals []string
assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
})
}
func TestConfigSqlConnErr(t *testing.T) {
t.Run("panic on empty config", func(t *testing.T) {
original := logx.ExitOnFatal.True()
logx.ExitOnFatal.Set(false)
defer logx.ExitOnFatal.Set(original)
assert.Panics(t, func() {
MustNewConn(SqlConf{})
})
})
t.Run("on error", func(t *testing.T) {
db, mock, err := sqlmock.New()
assert.NotNil(t, db)
assert.NotNil(t, mock)
assert.Nil(t, err)
connManager.Inject(mockedDatasource, db)
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := MustNewConn(conf)
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
return nil, errors.New("error")
}
_, err = conn.Prepare("any")
assert.Error(t, err)
})
}
func TestStatement(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any").WillBeClosed()
@@ -303,6 +446,93 @@ func TestWithAcceptable(t *testing.T) {
assert.True(t, conn.accept(acceptableErr3))
}
func TestProvider(t *testing.T) {
defer func() {
_ = connManager.Close()
}()
primaryDSN := "primary:password@tcp(127.0.0.1:3306)/primary_db"
replicasDSN := []string{
"replica_one:pwd@tcp(localhost:3306)/replica_one",
"replica_two:pwd@tcp(localhost:3306)/replica_two",
"replica_three:pwd@tcp(localhost:3306)/replica_three",
}
primaryDB, err := connManager.GetResource(primaryDSN, func() (io.Closer, error) { return sql.Open(mysqlDriverName, primaryDSN) })
assert.Nil(t, err)
assert.NotNil(t, primaryDB)
replicaOneDB, err := connManager.GetResource(replicasDSN[0], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[0]) })
assert.Nil(t, err)
assert.NotNil(t, replicaOneDB)
replicaTwoDB, err := connManager.GetResource(replicasDSN[1], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[1]) })
assert.Nil(t, err)
assert.NotNil(t, replicaTwoDB)
replicaThreeDB, err := connManager.GetResource(replicasDSN[2], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[2]) })
assert.Nil(t, err)
assert.NotNil(t, replicaThreeDB)
sc := &commonSqlConn{}
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, nil)
ctx := context.Background()
db, err := sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
ctx = WithWrite(ctx)
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
ctx = WithReadPrimary(ctx)
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
// no mode set, should return primary
ctx = context.Background()
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, primaryDB, db)
ctx = WithReadReplica(ctx)
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]})
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, replicaOneDB, db)
// default policy is round-robin
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
replicas := []io.Closer{replicaOneDB, replicaTwoDB, replicaThreeDB}
for i := 0; i < len(replicasDSN); i++ {
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, replicas[i], db)
}
// random policy
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRandom, replicasDSN)
for i := 0; i < len(replicasDSN); i++ {
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Contains(t, replicas, db)
}
// unknown policy
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "unknown", replicasDSN)
_, err = sc.connProv(ctx)
assert.NotNil(t, err)
// empty policy transforms to round-robin
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "", replicasDSN)
for i := 0; i < len(replicasDSN); i++ {
db, err = sc.connProv(ctx)
assert.Nil(t, err)
assert.Equal(t, replicas[i], db)
}
}
func buildConn() (mock sqlmock.Sqlmock, err error) {
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB

View File

@@ -27,7 +27,7 @@ func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
return nil, err
}
if driverName != mysqlDriverName {
if driverName == mysqlDriverName {
if cfg, e := mysql.ParseDSN(server); e != nil {
// if cannot parse, don't collect the metrics
logx.Error(e)

View File

@@ -156,7 +156,7 @@ func begin(db *sql.DB) (trans, error) {
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
fn func(context.Context, Session) error) (err error) {
conn, err := db.connProv()
conn, err := db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err

View File

@@ -117,7 +117,7 @@ func TestTxExceptions(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
connProv: func(ctx context.Context) (*sql.DB, error) {
return nil, errors.New("foo")
},
beginTx: begin,

9
go.mod
View File

@@ -4,7 +4,7 @@ go 1.21
require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/alicebob/miniredis/v2 v2.34.0
github.com/alicebob/miniredis/v2 v2.35.0
github.com/fatih/color v1.18.0
github.com/fullstorydev/grpcurl v1.9.3
github.com/go-sql-driver/mysql v1.9.0
@@ -12,16 +12,17 @@ require (
github.com/golang/mock v1.6.0
github.com/golang/protobuf v1.5.4
github.com/google/uuid v1.6.0
github.com/grafana/pyroscope-go v1.2.2
github.com/jackc/pgx/v5 v5.7.4
github.com/jhump/protoreflect v1.17.0
github.com/pelletier/go-toml/v2 v2.2.2
github.com/prometheus/client_golang v1.21.1
github.com/redis/go-redis/v9 v9.8.0
github.com/redis/go-redis/v9 v9.11.0
github.com/spaolacci/murmur3 v1.1.0
github.com/stretchr/testify v1.10.0
go.etcd.io/etcd/api/v3 v3.5.15
go.etcd.io/etcd/client/v3 v3.5.15
go.mongodb.org/mongo-driver v1.17.3
go.mongodb.org/mongo-driver v1.17.4
go.opentelemetry.io/otel v1.24.0
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
@@ -49,7 +50,6 @@ require (
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bufbuild/protocompile v0.14.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
@@ -72,6 +72,7 @@ require (
github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect

18
go.sum
View File

@@ -2,10 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8=
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
@@ -82,6 +80,10 @@ github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJY
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
@@ -156,8 +158,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI=
github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs=
github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -200,8 +202,8 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
go.mongodb.org/mongo-driver v1.17.3 h1:TQyXhnsWfWtgAhMtOgtYHMTkZIfBTpMTsMnd9ZBeHxQ=
go.mongodb.org/mongo-driver v1.17.3/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
go.mongodb.org/mongo-driver v1.17.4 h1:jUorfmVzljjr0FLzYQsGP8cgN/qzzxlY9Vh0C9KFXVw=
go.mongodb.org/mongo-driver v1.17.4/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=

View File

@@ -0,0 +1,263 @@
package profiling
import (
"runtime"
"sync"
"time"
"github.com/grafana/pyroscope-go"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/threading"
)
const (
defaultCheckInterval = time.Second * 10
defaultProfilingDuration = time.Minute * 2
defaultUploadRate = time.Second * 15
)
type (
Config struct {
// Name is the name of the application.
Name string `json:",optional,inherit"`
// ServerAddr is the address of the profiling server.
ServerAddr string
// AuthUser is the username for basic authentication.
AuthUser string `json:",optional"`
// AuthPassword is the password for basic authentication.
AuthPassword string `json:",optional"`
// UploadRate is the duration for which profiling data is uploaded.
UploadRate time.Duration `json:",default=15s"`
// CheckInterval is the interval to check if profiling should start.
CheckInterval time.Duration `json:",default=10s"`
// ProfilingDuration is the duration for which profiling data is collected.
ProfilingDuration time.Duration `json:",default=2m"`
// CpuThreshold the collection is allowed only when the current service cpu > CpuThreshold
CpuThreshold int64 `json:",default=700,range=[0:1000)"`
// ProfileType is the type of profiling to be performed.
ProfileType ProfileType
}
ProfileType struct {
// Logger is a flag to enable or disable logging.
Logger bool `json:",default=false"`
// CPU is a flag to disable CPU profiling.
CPU bool `json:",default=true"`
// Goroutines is a flag to disable goroutine profiling.
Goroutines bool `json:",default=true"`
// Memory is a flag to disable memory profiling.
Memory bool `json:",default=true"`
// Mutex is a flag to disable mutex profiling.
Mutex bool `json:",default=false"`
// Block is a flag to disable block profiling.
Block bool `json:",default=false"`
}
profiler interface {
Start() error
Stop() error
}
pyroscopeProfiler struct {
c Config
profiler *pyroscope.Profiler
}
)
var (
once sync.Once
newProfiler = func(c Config) profiler {
return newPyroscopeProfiler(c)
}
)
// Start initializes the pyroscope profiler with the given configuration.
func Start(c Config) {
// check if the profiling is enabled
if len(c.ServerAddr) == 0 {
return
}
// set default values for the configuration
if c.ProfilingDuration <= 0 {
c.ProfilingDuration = defaultProfilingDuration
}
// set default values for the configuration
if c.CheckInterval <= 0 {
c.CheckInterval = defaultCheckInterval
}
if c.UploadRate <= 0 {
c.UploadRate = defaultUploadRate
}
once.Do(func() {
logx.Info("continuous profiling started")
threading.GoSafe(func() {
startPyroscope(c, proc.Done())
})
})
}
// startPyroscope starts the pyroscope profiler with the given configuration.
func startPyroscope(c Config, done <-chan struct{}) {
var (
pr profiler
err error
latestProfilingTime time.Time
intervalTicker = time.NewTicker(c.CheckInterval)
profilingTicker = time.NewTicker(c.ProfilingDuration)
)
defer profilingTicker.Stop()
defer intervalTicker.Stop()
for {
select {
case <-intervalTicker.C:
// Check if the machine is overloaded and if the profiler is not running
if pr == nil && isCpuOverloaded(c) {
pr = newProfiler(c)
if err := pr.Start(); err != nil {
logx.Errorf("failed to start profiler: %v", err)
continue
}
// record the latest profiling time
latestProfilingTime = time.Now()
logx.Infof("pyroscope profiler started.")
}
case <-profilingTicker.C:
// check if the profiling duration has passed
if !time.Now().After(latestProfilingTime.Add(c.ProfilingDuration)) {
continue
}
// check if the profiler is already running, if so, skip
if pr != nil {
if err = pr.Stop(); err != nil {
logx.Errorf("failed to stop profiler: %v", err)
}
logx.Infof("pyroscope profiler stopped.")
pr = nil
}
case <-done:
logx.Infof("continuous profiling stopped.")
return
}
}
}
// genPyroscopeConf generates the pyroscope configuration based on the given config.
func genPyroscopeConf(c Config) pyroscope.Config {
pConf := pyroscope.Config{
UploadRate: c.UploadRate,
ApplicationName: c.Name,
BasicAuthUser: c.AuthUser, // http basic auth user
BasicAuthPassword: c.AuthPassword, // http basic auth password
ServerAddress: c.ServerAddr,
Logger: nil,
HTTPHeaders: map[string]string{},
// you can provide static tags via a map:
Tags: map[string]string{
"name": c.Name,
},
}
if c.ProfileType.Logger {
pConf.Logger = logx.WithCallerSkip(0)
}
if c.ProfileType.CPU {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileCPU)
}
if c.ProfileType.Goroutines {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileGoroutines)
}
if c.ProfileType.Memory {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileAllocObjects, pyroscope.ProfileAllocSpace,
pyroscope.ProfileInuseObjects, pyroscope.ProfileInuseSpace)
}
if c.ProfileType.Mutex {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileMutexCount, pyroscope.ProfileMutexDuration)
}
if c.ProfileType.Block {
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileBlockCount, pyroscope.ProfileBlockDuration)
}
logx.Infof("applicationName: %s", pConf.ApplicationName)
return pConf
}
// isCpuOverloaded checks the machine performance based on the given configuration.
func isCpuOverloaded(c Config) bool {
currentValue := stat.CpuUsage()
if currentValue >= c.CpuThreshold {
logx.Infof("continuous profiling cpu overload, cpu: %d", currentValue)
return true
}
return false
}
func newPyroscopeProfiler(c Config) profiler {
return &pyroscopeProfiler{
c: c,
}
}
func (p *pyroscopeProfiler) Start() error {
pConf := genPyroscopeConf(p.c)
// set mutex and block profile rate
setFraction(p.c)
prof, err := pyroscope.Start(pConf)
if err != nil {
resetFraction(p.c)
return err
}
p.profiler = prof
return nil
}
func (p *pyroscopeProfiler) Stop() error {
if p.profiler == nil {
return nil
}
if err := p.profiler.Stop(); err != nil {
return err
}
resetFraction(p.c)
p.profiler = nil
return nil
}
func setFraction(c Config) {
// These 2 lines are only required if you're using mutex or block profiling
if c.ProfileType.Mutex {
runtime.SetMutexProfileFraction(10) // 10/seconds
}
if c.ProfileType.Block {
runtime.SetBlockProfileRate(1000 * 1000) // 1/millisecond
}
}
func resetFraction(c Config) {
// These 2 lines are only required if you're using mutex or block profiling
if c.ProfileType.Mutex {
runtime.SetMutexProfileFraction(0)
}
if c.ProfileType.Block {
runtime.SetBlockProfileRate(0)
}
}

View File

@@ -0,0 +1,177 @@
package profiling
import (
"sync"
"testing"
"time"
"github.com/grafana/pyroscope-go"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/syncx"
)
func TestStart(t *testing.T) {
t.Run("profiling", func(t *testing.T) {
var c Config
assert.NoError(t, conf.FillDefault(&c))
c.Name = "test"
p := newProfiler(c)
assert.NotNil(t, p)
assert.NoError(t, p.Start())
assert.NoError(t, p.Stop())
})
t.Run("invalid config", func(t *testing.T) {
mp := &mockProfiler{}
newProfiler = func(c Config) profiler {
return mp
}
Start(Config{})
Start(Config{
ServerAddr: "localhost:4040",
})
})
t.Run("test start profiler", func(t *testing.T) {
mp := &mockProfiler{}
newProfiler = func(c Config) profiler {
return mp
}
c := Config{
Name: "test",
ServerAddr: "localhost:4040",
CheckInterval: time.Millisecond,
ProfilingDuration: time.Millisecond * 10,
CpuThreshold: 0,
}
var done = make(chan struct{})
go startPyroscope(c, done)
time.Sleep(time.Millisecond * 50)
close(done)
assert.True(t, mp.started.True())
assert.True(t, mp.stopped.True())
})
t.Run("test start profiler with cpu overloaded", func(t *testing.T) {
mp := &mockProfiler{}
newProfiler = func(c Config) profiler {
return mp
}
c := Config{
Name: "test",
ServerAddr: "localhost:4040",
CheckInterval: time.Millisecond,
ProfilingDuration: time.Millisecond * 10,
CpuThreshold: 900,
}
var done = make(chan struct{})
go startPyroscope(c, done)
time.Sleep(time.Millisecond * 50)
close(done)
assert.False(t, mp.started.True())
})
t.Run("start/stop err", func(t *testing.T) {
mp := &mockProfiler{
err: assert.AnError,
}
newProfiler = func(c Config) profiler {
return mp
}
c := Config{
Name: "test",
ServerAddr: "localhost:4040",
CheckInterval: time.Millisecond,
ProfilingDuration: time.Millisecond * 10,
CpuThreshold: 0,
}
var done = make(chan struct{})
go startPyroscope(c, done)
time.Sleep(time.Millisecond * 50)
close(done)
assert.False(t, mp.started.True())
assert.False(t, mp.stopped.True())
})
}
func TestGenPyroscopeConf(t *testing.T) {
c := Config{
Name: "",
ServerAddr: "localhost:4040",
AuthUser: "user",
AuthPassword: "password",
ProfileType: ProfileType{
Logger: true,
CPU: true,
Goroutines: true,
Memory: true,
Mutex: true,
Block: true,
},
}
pyroscopeConf := genPyroscopeConf(c)
assert.Equal(t, c.ServerAddr, pyroscopeConf.ServerAddress)
assert.Equal(t, c.AuthUser, pyroscopeConf.BasicAuthUser)
assert.Equal(t, c.AuthPassword, pyroscopeConf.BasicAuthPassword)
assert.Equal(t, c.Name, pyroscopeConf.ApplicationName)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileCPU)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileGoroutines)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocObjects)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocSpace)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseObjects)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseSpace)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexCount)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexDuration)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockCount)
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockDuration)
setFraction(c)
resetFraction(c)
newPyroscopeProfiler(c)
}
func TestNewPyroscopeProfiler(t *testing.T) {
p := newPyroscopeProfiler(Config{})
assert.Error(t, p.Start())
assert.NoError(t, p.Stop())
}
type mockProfiler struct {
mutex sync.Mutex
started syncx.AtomicBool
stopped syncx.AtomicBool
err error
}
func (m *mockProfiler) Start() error {
m.mutex.Lock()
if m.err == nil {
m.started.Set(true)
}
m.mutex.Unlock()
return m.err
}
func (m *mockProfiler) Stop() error {
m.mutex.Lock()
if m.err == nil {
m.stopped.Set(true)
}
m.mutex.Unlock()
return m.err
}

View File

@@ -173,17 +173,20 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
return
}
w.WriteHeader(http.StatusAccepted)
// For notification methods (no ID), we don't send a response
isNotification := req.ID == 0
isNotification, err := req.isNotification()
if err != nil {
http.Error(w, "Invalid request.ID", http.StatusBadRequest)
}
w.WriteHeader(http.StatusAccepted)
// Special handling for initialization sequence
// Always allow initialize and notifications/initialized regardless of client state
if req.Method == methodInitialize {
logx.Infof("Processing initialize request with ID: %d", req.ID)
logx.Infof("Processing initialize request with ID: %v", req.ID)
s.processInitialize(r.Context(), client, req)
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID)
return
} else if req.Method == methodNotificationsInitialized {
// Handle initialized notification
@@ -206,41 +209,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
// Process normal requests only after initialization
switch req.Method {
case methodToolsCall:
logx.Infof("Received tools call request with ID: %d", req.ID)
logx.Infof("Received tools call request with ID: %v", req.ID)
s.processToolCall(r.Context(), client, req)
logx.Infof("Sent tools call response for ID: %d", req.ID)
logx.Infof("Sent tools call response for ID: %v", req.ID)
case methodToolsList:
logx.Infof("Processing tools/list request with ID: %d", req.ID)
logx.Infof("Processing tools/list request with ID: %v", req.ID)
s.processListTools(r.Context(), client, req)
logx.Infof("Sent tools/list response for ID: %d", req.ID)
logx.Infof("Sent tools/list response for ID: %v", req.ID)
case methodPromptsList:
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
logx.Infof("Processing prompts/list request with ID: %v", req.ID)
s.processListPrompts(r.Context(), client, req)
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
logx.Infof("Sent prompts/list response for ID: %v", req.ID)
case methodPromptsGet:
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
logx.Infof("Processing prompts/get request with ID: %v", req.ID)
s.processGetPrompt(r.Context(), client, req)
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
logx.Infof("Sent prompts/get response for ID: %v", req.ID)
case methodResourcesList:
logx.Infof("Processing resources/list request with ID: %d", req.ID)
logx.Infof("Processing resources/list request with ID: %v", req.ID)
s.processListResources(r.Context(), client, req)
logx.Infof("Sent resources/list response for ID: %d", req.ID)
logx.Infof("Sent resources/list response for ID: %v", req.ID)
case methodResourcesRead:
logx.Infof("Processing resources/read request with ID: %d", req.ID)
logx.Infof("Processing resources/read request with ID: %v", req.ID)
s.processResourcesRead(r.Context(), client, req)
logx.Infof("Sent resources/read response for ID: %d", req.ID)
logx.Infof("Sent resources/read response for ID: %v", req.ID)
case methodResourcesSubscribe:
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
logx.Infof("Processing resources/subscribe request with ID: %v", req.ID)
s.processResourceSubscribe(r.Context(), client, req)
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
logx.Infof("Sent resources/subscribe response for ID: %v", req.ID)
case methodPing:
logx.Infof("Processing ping request with ID: %d", req.ID)
logx.Infof("Processing ping request with ID: %v", req.ID)
s.processPing(r.Context(), client, req)
case methodNotificationsCancelled:
logx.Infof("Received notifications/cancelled notification: %d", req.ID)
logx.Infof("Received notifications/cancelled notification: %v", req.ID)
s.processNotificationCancelled(r.Context(), client, req)
default:
logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID)
logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID)
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
}
}
@@ -880,10 +883,10 @@ func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req R
// sendErrorResponse sends an error response via the SSE channel
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
id int64, message string, code int) {
id any, message string, code int) {
errorResponse := struct {
JsonRpc string `json:"jsonrpc"`
ID int64 `json:"id"`
ID any `json:"id"`
Error errorMessage `json:"error"`
}{
JsonRpc: jsonRpcVersion,
@@ -898,7 +901,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
jsonData, _ := json.Marshal(errorResponse)
// Use CRLF line endings as requested
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
logx.Infof("Sending error for ID %v: %s", id, sseMessage)
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
select {
@@ -910,7 +913,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
}
// sendResponse sends a success response via the SSE channel
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) {
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) {
response := Response{
JsonRpc: jsonRpcVersion,
ID: id,
@@ -925,13 +928,13 @@ func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id i
// Use CRLF line endings as requested
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
logx.Infof("Sending response for ID %v: %s", id, sseMessage)
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
select {
case client.channel <- sseMessage:
default:
// Channel buffer is full, log warning and continue
logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id)
logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id)
}
}

View File

@@ -175,6 +175,20 @@ func TestHandleRequest_badRequest(t *testing.T) {
mock.server.handleRequest(w, r)
assert.Equal(t, http.StatusBadRequest, w.Code)
})
t.Run("bad id", func(t *testing.T) {
mock := newMockMcpServer(t)
defer mock.shutdown()
addTestClient(mock.server, "test-session", true)
body := `{"jsonrpc": "2.0", "id": {}, "method": "tools.call", "params": {}}`
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-session", bytes.NewReader([]byte(body)))
w := httptest.NewRecorder()
mock.server.handleRequest(w, r)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "Invalid request.ID")
})
}
func TestRegisterTool(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package mcp
import (
"context"
"encoding/json"
"fmt"
"sync"
"github.com/zeromicro/go-zero/rest"
@@ -15,11 +16,28 @@ type Cursor string
type Request struct {
SessionId string `form:"session_id"` // Session identifier for client tracking
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
ID int64 `json:"id"` // Request identifier for matching responses
ID any `json:"id"` // Request identifier for matching responses
Method string `json:"method"` // Method name to invoke
Params json.RawMessage `json:"params"` // Parameters for the method
}
func (r Request) isNotification() (bool, error) {
switch val := r.ID.(type) {
case int:
return val == 0, nil
case int64:
return val == 0, nil
case float64:
return val == 0.0, nil
case string:
return len(val) == 0, nil
case nil:
return true, nil
default:
return false, fmt.Errorf("invalid type %T", val)
}
}
type PaginatedParams struct {
Cursor string `json:"cursor"`
Meta struct {
@@ -244,7 +262,7 @@ type errorObj struct {
// Response represents a JSON-RPC response
type Response struct {
JsonRpc string `json:"jsonrpc"` // Always "2.0"
ID int64 `json:"id"` // Same as request ID
ID any `json:"id"` // Same as request ID
Result any `json:"result"` // Result object (null if error)
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
}

View File

@@ -3,6 +3,7 @@ package mcp
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/stretchr/testify/assert"
@@ -55,7 +56,7 @@ func TestRequestUnmarshaling(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "2.0", req.JsonRpc)
assert.Equal(t, int64(789), req.ID)
assert.Equal(t, float64(789), req.ID)
assert.Equal(t, "test_method", req.Method)
// Check params unmarshaled correctly
@@ -204,3 +205,67 @@ func TestCallToolResult(t *testing.T) {
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
assert.NotContains(t, string(data), `"isError":`)
}
func TestRequest_isNotification(t *testing.T) {
tests := []struct {
name string
id any
want bool
wantErr error
}{
// integer test cases
{name: "int zero", id: 0, want: true, wantErr: nil},
{name: "int non-zero", id: 1, want: false, wantErr: nil},
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
// floating point number test cases
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
// string test cases
{name: "empty string", id: "", want: true, wantErr: nil},
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
{name: "space string", id: " ", want: false, wantErr: nil},
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
// special cases
{name: "nil", id: nil, want: true, wantErr: nil},
// logical type test cases
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := Request{
SessionId: "test-session",
JsonRpc: "2.0",
ID: tt.id,
Method: "testMethod",
Params: json.RawMessage(`{}`),
}
got, err := req.isNotification()
if (err != nil) != (tt.wantErr != nil) {
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
}
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
}
if got != tt.want {
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
}
})
}
}

View File

@@ -6,7 +6,6 @@
[English](readme.md) | 简体中文
[![Go](https://github.com/zeromicro/go-zero/workflows/Go/badge.svg?branch=master)](https://github.com/zeromicro/go-zero/actions)
[![Go Report Card](https://goreportcard.com/badge/github.com/zeromicro/go-zero)](https://goreportcard.com/report/github.com/zeromicro/go-zero)
[![goproxy](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg)](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg)
[![codecov](https://codecov.io/gh/zeromicro/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/zeromicro/go-zero)
@@ -304,6 +303,7 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
>105. 昆仑万维科技股份有限公司
>106. 无锡盛算信息技术有限公司
>107. 深圳市聚货通信息科技有限公司
>108. 浙江银盾云科技有限公司
如果贵公司也已使用 go-zero欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。

View File

@@ -7,7 +7,6 @@ go-zero is a web and rpc framework with lots of builtin engineering practices. I
<div align=center>
[![Go](https://github.com/zeromicro/go-zero/workflows/Go/badge.svg?branch=master)](https://github.com/zeromicro/go-zero/actions)
[![codecov](https://codecov.io/gh/zeromicro/go-zero/branch/master/graph/badge.svg)](https://codecov.io/gh/zeromicro/go-zero)
[![Go Report Card](https://goreportcard.com/badge/github.com/zeromicro/go-zero)](https://goreportcard.com/report/github.com/zeromicro/go-zero)
[![Release](https://img.shields.io/github/v/release/zeromicro/go-zero.svg?style=flat-square)](https://github.com/zeromicro/go-zero)
@@ -251,7 +250,3 @@ go-zero enlisted in the [CNCF Cloud Native Landscape](https://landscape.cncf.io/
## Give a Star! ⭐
If you like this project or are using it to learn or start your own solution, give it a star to get updates on new releases. Your support matters!
## Buy me a coffee
<a href="https://www.buymeacoffee.com/kevwan" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>

View File

@@ -28,7 +28,10 @@ var ErrSignatureConfig = errors.New("bad config for Signature")
type engine struct {
conf RestConf
routes []featuredRoutes
// timeout is the max timeout of all routes
// timeout is the max timeout of all routes,
// and is used to set http.Server.ReadTimeout and http.Server.WriteTimeout.
// this network timeout is used to avoid DoS attacks by sending data slowly
// or receiving data slowly with many connections to exhaust server resources.
timeout time.Duration
unauthorizedCallback handler.UnauthorizedCallback
unsignedCallback handler.UnsignedCallback
@@ -60,11 +63,7 @@ func (ng *engine) addRoutes(r featuredRoutes) {
}
ng.routes = append(ng.routes, r)
// need to guarantee the timeout is the max of all routes
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
if r.timeout > ng.timeout {
ng.timeout = r.timeout
}
ng.mightUpdateTimeout(r)
}
func buildSSERoutes(routes []Route) []Route {
@@ -192,11 +191,12 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 {
return ng.conf.MaxBytes
}
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
if timeout > 0 {
return timeout
func (ng *engine) checkedTimeout(timeout *time.Duration) time.Duration {
if timeout != nil {
return *timeout
}
// if timeout not set in featured routes, use global timeout
return time.Duration(ng.conf.Timeout) * time.Millisecond
}
@@ -232,6 +232,28 @@ func (ng *engine) hasTimeout() bool {
return ng.conf.Middlewares.Timeout && ng.timeout > 0
}
// mightUpdateTimeout checks if the route timeout is greater than the current,
// and updates the engine's timeout accordingly.
func (ng *engine) mightUpdateTimeout(r featuredRoutes) {
// if global timeout is set to 0, it means no need to set read/write timeout
// if route timeout is nil, no need to update ng.timeout
if ng.timeout == 0 || r.timeout == nil {
return
}
// if route timeout is 0 (means no timeout), cannot set read/write timeout
if *r.timeout == 0 {
ng.timeout = 0
return
}
// need to guarantee the timeout is the max of all routes
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
if *r.timeout > ng.timeout {
ng.timeout = *r.timeout
}
}
// notFoundHandler returns a middleware that handles 404 not found requests.
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -333,7 +355,7 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
}
// make sure user defined options overwrite default options
opts = append([]StartOption{ng.withTimeout()}, opts...)
opts = append([]StartOption{ng.withNetworkTimeout()}, opts...)
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
@@ -356,7 +378,7 @@ func (ng *engine) use(middleware Middleware) {
ng.middlewares = append(ng.middlewares, middleware)
}
func (ng *engine) withTimeout() internal.StartOption {
func (ng *engine) withNetworkTimeout() internal.StartOption {
return func(svr *http.Server) {
if !ng.hasTimeout() {
return

View File

@@ -73,7 +73,17 @@ Verbose: true
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
timeout: time.Minute,
timeout: ptrOfDuration(time.Minute),
},
{
jwt: jwtSetting{},
signature: signatureSetting{},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
timeout: ptrOfDuration(0),
},
{
priority: true,
@@ -84,7 +94,7 @@ Verbose: true
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
timeout: time.Second,
timeout: ptrOfDuration(time.Second),
},
{
priority: true,
@@ -227,8 +237,12 @@ Verbose: true
}))
timeout := time.Second * 3
if route.timeout > timeout {
timeout = route.timeout
if route.timeout != nil {
if *route.timeout == 0 {
timeout = 0
} else if *route.timeout > timeout {
timeout = *route.timeout
}
}
assert.Equal(t, timeout, ng.timeout)
})
@@ -236,10 +250,69 @@ Verbose: true
}
}
func TestNewEngine_unsignedCallback(t *testing.T) {
priKeyfile, err := fs.TempFilenameWithText(priKey)
assert.Nil(t, err)
defer os.Remove(priKeyfile)
yaml := `Name: foo
Host: localhost
Port: 0
Middlewares:
Log: false
`
route := featuredRoutes{
priority: true,
jwt: jwtSetting{
enabled: true,
},
signature: signatureSetting{
enabled: true,
SignatureConf: SignatureConf{
Strict: true,
PrivateKeys: []PrivateKeyConf{
{
Fingerprint: "a",
KeyFile: priKeyfile,
},
},
},
},
routes: []Route{{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {},
}},
}
var index int32
t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
var cnf RestConf
assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
ng := newEngine(cnf)
if atomic.AddInt32(&index, 1)%2 == 0 {
ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request,
next http.Handler, strict bool, code int) {
})
}
ng.addRoutes(route)
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
}
})
assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
}))
assert.Equal(t, time.Duration(time.Second*3), ng.timeout)
})
}
func TestEngine_checkedTimeout(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
timeout *time.Duration
expect time.Duration
}{
{
@@ -248,17 +321,17 @@ func TestEngine_checkedTimeout(t *testing.T) {
},
{
name: "less",
timeout: time.Millisecond * 500,
timeout: ptrOfDuration(time.Millisecond * 500),
expect: time.Millisecond * 500,
},
{
name: "equal",
timeout: time.Second,
timeout: ptrOfDuration(time.Second),
expect: time.Second,
},
{
name: "more",
timeout: time.Millisecond * 1500,
timeout: ptrOfDuration(time.Millisecond * 1500),
expect: time.Millisecond * 1500,
},
}
@@ -401,7 +474,7 @@ func TestEngine_withTimeout(t *testing.T) {
},
})
svr := &http.Server{}
ng.withTimeout()(svr)
ng.withNetworkTimeout()(svr)
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
@@ -451,7 +524,7 @@ func TestEngine_ReadWriteTimeout(t *testing.T) {
},
})
svr := &http.Server{}
ng.withTimeout()(svr)
ng.withNetworkTimeout()(svr)
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
assert.Equal(t, time.Duration(0), svr.IdleTimeout)

View File

@@ -106,8 +106,8 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case <-ctx.Done():
tw.mu.Lock()
defer tw.mu.Unlock()
// there isn't any user-defined middleware before TimoutHandler,
// so we can guarantee that cancelation in biz related code won't come here.
// there isn't any user-defined middleware before TimeoutHandler,
// so we can guarantee that cancellation in biz related code won't come here.
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
if errors.Is(err, context.Canceled) {
w.WriteHeader(statusClientClosedRequest)
@@ -151,7 +151,7 @@ func (tw *timeoutWriter) Flush() {
flusher.Flush()
}
// Header returns the underline temporary http.Header.
// Header returns the underlying temporary http.Header.
func (tw *timeoutWriter) Header() http.Header {
return tw.h
}

View File

@@ -119,6 +119,16 @@ func (s *Server) Use(middleware Middleware) {
s.ngin.use(middleware)
}
// build builds the Server and binds the routes to the router.
func (s *Server) build() error {
return s.ngin.bindRoutes(s.router)
}
// serve serves the HTTP requests using the Server's router.
func (s *Server) serve(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)
}
// ToMiddleware converts the given handler to a Middleware.
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
return func(handle http.HandlerFunc) http.HandlerFunc {
@@ -283,14 +293,14 @@ func WithSignature(signature SignatureConf) RouteOption {
func WithSSE() RouteOption {
return func(r *featuredRoutes) {
r.sse = true
r.timeout = 0
r.timeout = ptrOfDuration(0)
}
}
// WithTimeout returns a RouteOption to set timeout with given value.
func WithTimeout(timeout time.Duration) RouteOption {
return func(r *featuredRoutes) {
r.timeout = timeout
r.timeout = &timeout
}
}
@@ -325,6 +335,10 @@ func handleError(err error) {
panic(err)
}
func ptrOfDuration(d time.Duration) *time.Duration {
return &d
}
func validateSecret(secret string) {
if len(secret) < 8 {
panic("secret's length can't be less than 8")

View File

@@ -345,7 +345,7 @@ func TestWithPriority(t *testing.T) {
func TestWithTimeout(t *testing.T) {
var fr featuredRoutes
WithTimeout(time.Hour)(&fr)
assert.Equal(t, time.Hour, fr.timeout)
assert.Equal(t, time.Hour, *fr.timeout)
}
func TestWithTLSConfig(t *testing.T) {
@@ -819,6 +819,6 @@ func TestServerEmbedFileSystem(t *testing.T) {
// serve(server, w, r)
// // verify the response
func serve(s *Server, w http.ResponseWriter, r *http.Request) {
s.ngin.bindRoutes(s.router)
s.router.ServeHTTP(w, r)
_ = s.build()
s.serve(w, r)
}

27
rest/serverless.go Normal file
View File

@@ -0,0 +1,27 @@
package rest
import "net/http"
// Serverless is a wrapper around Server that allows it to be used in serverless environments.
type Serverless struct {
server *Server
}
// NewServerless creates a new Serverless instance from the provided Server.
func NewServerless(server *Server) (*Serverless, error) {
// Ensure the server is built before using it in a serverless context.
// Why not call server.build() when serving requests,
// is because we need to ensure fail fast behavior.
if err := server.build(); err != nil {
return nil, err
}
return &Serverless{
server: server,
}, nil
}
// Serve handles HTTP requests by delegating them to the underlying Server instance.
func (s *Serverless) Serve(w http.ResponseWriter, r *http.Request) {
s.server.serve(w, r)
}

67
rest/serverless_test.go Normal file
View File

@@ -0,0 +1,67 @@
package rest
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/logx/logtest"
)
func TestNewServerless(t *testing.T) {
logtest.Discard(t)
const configYaml = `
Name: foo
Host: localhost
Port: 0
`
var cnf RestConf
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
svr, err := NewServer(cnf)
assert.NoError(t, err)
svr.AddRoute(Route{
Method: http.MethodGet,
Path: "/",
Handler: func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello World"))
},
})
serverless, err := NewServerless(svr)
assert.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
serverless.Serve(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "Hello World", w.Body.String())
}
func TestNewServerlessWithError(t *testing.T) {
logtest.Discard(t)
const configYaml = `
Name: foo
Host: localhost
Port: 0
`
var cnf RestConf
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
svr, err := NewServer(cnf)
assert.NoError(t, err)
svr.AddRoute(Route{
Method: http.MethodGet,
Path: "notstartwith/",
Handler: nil,
})
_, err = NewServerless(svr)
assert.Error(t, err)
}

View File

@@ -31,7 +31,7 @@ type (
}
featuredRoutes struct {
timeout time.Duration
timeout *time.Duration
priority bool
jwt jwtSetting
signature signatureSetting

View File

@@ -49,6 +49,9 @@ func Parse(tag string) (*Tags, error) {
// Get gets tag value by specified key
func (t *Tags) Get(key string) (*Tag, error) {
if t == nil {
return nil, errTagNotExist
}
for _, tag := range t.tags {
if tag.Key == key {
return tag, nil
@@ -60,6 +63,9 @@ func (t *Tags) Get(key string) (*Tag, error) {
// Keys returns all keys in Tags
func (t *Tags) Keys() []string {
if t == nil {
return []string{}
}
var keys []string
for _, tag := range t.tags {
keys = append(keys, tag.Key)
@@ -69,5 +75,8 @@ func (t *Tags) Keys() []string {
// Tags returns all tags in Tags
func (t *Tags) Tags() []*Tag {
if t == nil {
return []*Tag{}
}
return t.tags
}

View File

@@ -0,0 +1,68 @@
package spec
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestTags_Get(t *testing.T) {
tags := &Tags{
tags: []*Tag{
{Key: "json", Name: "foo", Options: []string{"omitempty"}},
{Key: "xml", Name: "bar", Options: nil},
},
}
tag, err := tags.Get("json")
assert.NoError(t, err)
assert.NotNil(t, tag)
assert.Equal(t, "json", tag.Key)
assert.Equal(t, "foo", tag.Name)
_, err = tags.Get("yaml")
assert.Error(t, err)
var nilTags *Tags
_, err = nilTags.Get("json")
assert.Error(t, err)
}
func TestTags_Keys(t *testing.T) {
tags := &Tags{
tags: []*Tag{
{Key: "json", Name: "foo", Options: []string{"omitempty"}},
{Key: "xml", Name: "bar", Options: nil},
},
}
keys := tags.Keys()
expected := []string{"json", "xml"}
assert.Equal(t, expected, keys)
var nilTags *Tags
nilKeys := nilTags.Keys()
assert.Empty(t, nilKeys)
}
func TestTags_Tags(t *testing.T) {
tags := &Tags{
tags: []*Tag{
{Key: "json", Name: "foo", Options: []string{"omitempty"}},
{Key: "xml", Name: "bar", Options: nil},
},
}
result := tags.Tags()
assert.Len(t, result, 2)
assert.Equal(t, "json", result[0].Key)
assert.Equal(t, "foo", result[0].Name)
assert.Equal(t, []string{"omitempty"}, result[0].Options)
assert.Equal(t, "xml", result[1].Key)
assert.Equal(t, "bar", result[1].Name)
assert.Nil(t, result[1].Options)
var nilTags *Tags
nilResult := nilTags.Tags()
assert.Empty(t, nilResult)
}

View File

@@ -172,13 +172,12 @@ func sampleTypeFromGoType(ctx Context, tp apiSpec.Type) string {
}
}
func typeContainsTag(_ Context, structType apiSpec.DefineStruct, tag string) bool {
for _, field := range structType.Members {
tags, _ := apiSpec.Parse(field.Tag)
for _, t := range tags.Tags() {
if t.Key == tag {
return true
}
func typeContainsTag(ctx Context, structType apiSpec.DefineStruct, tag string) bool {
members := expandMembers(ctx, structType)
for _, member := range members {
tags, _ := apiSpec.Parse(member.Tag)
if _, err := tags.Get(tag); err == nil {
return true
}
}
return false

View File

@@ -4,7 +4,7 @@ go 1.21
require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/emicklei/proto v1.14.1
github.com/emicklei/proto v1.14.2
github.com/fatih/structtag v1.2.0
github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e
github.com/go-sql-driver/mysql v1.9.0
@@ -16,7 +16,7 @@ require (
github.com/withfig/autocomplete-tools/integrations/cobra v1.2.1
github.com/zeromicro/antlr v0.0.1
github.com/zeromicro/ddl-parser v1.0.5
github.com/zeromicro/go-zero v1.8.3
github.com/zeromicro/go-zero v1.8.4
golang.org/x/text v0.22.0
google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.36.5
@@ -25,8 +25,7 @@ require (
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
github.com/alicebob/miniredis/v2 v2.34.0 // indirect
github.com/alicebob/miniredis/v2 v2.35.0 // indirect
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
@@ -49,6 +48,8 @@ require (
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grafana/pyroscope-go v1.2.2 // indirect
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -72,7 +73,7 @@ require (
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/redis/go-redis/v9 v9.8.0 // indirect
github.com/redis/go-redis/v9 v9.10.0 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect

View File

@@ -2,10 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8=
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec h1:EEyRvzmpEUZ+I8WmD5cw/vY8EqhambkOqy5iFr0908A=
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
@@ -32,8 +30,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g=
github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/emicklei/proto v1.14.1 h1:fFq+Bj70XXZWXWikcVRvYZxrMS4KIIiPAqdJ8vPrenY=
github.com/emicklei/proto v1.14.1/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
github.com/emicklei/proto v1.14.2 h1:wJPxPy2Xifja9cEMrcA/g08art5+7CGJNFNk35iXC1I=
github.com/emicklei/proto v1.14.2/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4=
@@ -77,6 +75,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
@@ -146,8 +148,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI=
github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/redis/go-redis/v9 v9.10.0 h1:FxwK3eV8p/CQa0Ch276C7u2d0eNC9kCmAYQ7mCXCzVs=
github.com/redis/go-redis/v9 v9.10.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
@@ -183,8 +185,8 @@ github.com/zeromicro/antlr v0.0.1 h1:CQpIn/dc0pUjgGQ81y98s/NGOm2Hfru2NNio2I9mQgk
github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
github.com/zeromicro/ddl-parser v1.0.5 h1:LaVqHdzMTjasua1yYpIYaksxKqRzFrEukj2Wi2EbWaQ=
github.com/zeromicro/ddl-parser v1.0.5/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8=
github.com/zeromicro/go-zero v1.8.3 h1:AwpBJQLAsZAt4OOnK0eR8UU1Ja2RFBIXfKkHdnXQKfc=
github.com/zeromicro/go-zero v1.8.3/go.mod h1:EnuEA3XdIQvAvc4WWTskRTO0jM2/aQi7OXv1gKWRNJ0=
github.com/zeromicro/go-zero v1.8.4 h1:3s7kOoThCnkDoqCafsqSX58Y9osYTBIa5QEmomw07TE=
github.com/zeromicro/go-zero v1.8.4/go.mod h1:eM5f6If/RF+jG1wSCmlvfXD2h2l23vJwETI8oDpjYt4=
go.etcd.io/etcd/api/v3 v3.5.15 h1:3KpLJir1ZEBrYuV2v+Twaa/e2MdDCEZ/70H+lzEiwsk=
go.etcd.io/etcd/api/v3 v3.5.15/go.mod h1:N9EhGzXq58WuMllgH9ZvnEr7SI9pS0k0+DHZezGp7jM=
go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5fWlA=

View File

@@ -206,6 +206,7 @@
"short": "Generate mongo model",
"type": "Specified model type name",
"cache": "Generate code with cache [optional]",
"prefix": "Generate code with cache prefix [optional]",
"easy": "Generate code with auto generated CollectionName for easy declare [optional]",
"dir": "{{.goctl.model.dir}}",
"style": "{{.global.style}}",

View File

@@ -6,7 +6,7 @@ import (
)
// BuildVersion is the version of goctl.
const BuildVersion = "1.8.4-beta"
const BuildVersion = "1.8.4"
var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-bata": 2, "beta": 3, "released": 4, "": 5}

View File

@@ -62,6 +62,7 @@ func init() {
mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t")
mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c")
mongoCmdFlags.StringVarP(&mongo.VarStringPrefix, "prefix", "p")
mongoCmdFlags.BoolVarP(&mongo.VarBoolEasy, "easy", "e")
mongoCmdFlags.StringVarP(&mongo.VarStringDir, "dir", "d")
mongoCmdFlags.StringVar(&mongo.VarStringStyle, "style")

View File

@@ -17,6 +17,7 @@ import (
type Context struct {
Types []string
Cache bool
Prefix string
Easy bool
Output string
Cfg *config.Config
@@ -60,6 +61,7 @@ func generateModel(ctx *Context) error {
"Type": stringx.From(t).Title(),
"lowerType": stringx.From(t).Untitle(),
"Cache": ctx.Cache,
"Prefix": ctx.Prefix,
"version": version.BuildVersion,
}, output, true); err != nil {
return err

View File

@@ -19,6 +19,8 @@ var (
VarStringDir string
// VarBoolCache describes whether cache is enabled.
VarBoolCache bool
// VarStringPrefix string describes the prefix for the cache key.
VarStringPrefix string
// VarBoolEasy describes whether to generate Collection Name in the code for easy declare.
VarBoolEasy bool
// VarStringStyle describes the style.
@@ -35,6 +37,7 @@ var (
func Action(_ *cobra.Command, _ []string) error {
tp := VarStringSliceType
c := VarBoolCache
p := VarStringPrefix
easy := VarBoolEasy
o := strings.TrimSpace(VarStringDir)
s := VarStringStyle
@@ -74,6 +77,7 @@ func Action(_ *cobra.Command, _ []string) error {
return generate.Do(&generate.Context{
Types: tp,
Cache: c,
Prefix: p,
Easy: easy,
Output: a,
Cfg: cfg,

View File

@@ -13,7 +13,7 @@ import (
"go.mongodb.org/mongo-driver/mongo"
)
{{if .Cache}}var prefix{{.Type}}CacheKey = "cache:{{.lowerType}}:"{{end}}
{{if .Cache}}var prefix{{.Type}}CacheKey = "{{if .Prefix}}{{.Prefix}}:{{end}}cache:{{.lowerType}}:"{{end}}
type {{.lowerType}}Model interface{
Insert(ctx context.Context,data *{{.Type}}) error