mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-27 00:25:29 +08:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
888551627c | ||
|
|
bd623aaac3 | ||
|
|
9e6c2ba2c0 | ||
|
|
c0db8d017d | ||
|
|
52b4f8ca91 | ||
|
|
4884a7b3c6 | ||
|
|
3c6951577d | ||
|
|
fcd15c9b17 | ||
|
|
155e6061cb | ||
|
|
dda7666097 | ||
|
|
c954568b61 | ||
|
|
c2acc43a52 | ||
|
|
1a1a6f5239 | ||
|
|
60c7edf8f8 | ||
|
|
7ad86a52f3 | ||
|
|
1e4e5a02b2 | ||
|
|
39540e21d2 | ||
|
|
b321622c95 | ||
|
|
a25cba5380 |
@@ -26,7 +26,8 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
|
|||||||
ctx, cancel := contextx.ShrinkDeadline(parentCtx, timeout)
|
ctx, cancel := contextx.ShrinkDeadline(parentCtx, timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
done := make(chan error)
|
// create channel with buffer size 1 to avoid goroutine leak
|
||||||
|
done := make(chan error, 1)
|
||||||
panicChan := make(chan interface{}, 1)
|
panicChan := make(chan interface{}, 1)
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -35,7 +36,6 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
done <- fn()
|
done <- fn()
|
||||||
close(done)
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package logx
|
|||||||
type LogConf struct {
|
type LogConf struct {
|
||||||
ServiceName string `json:",optional"`
|
ServiceName string `json:",optional"`
|
||||||
Mode string `json:",default=console,options=console|file|volume"`
|
Mode string `json:",default=console,options=console|file|volume"`
|
||||||
|
TimeFormat string `json:",optional"`
|
||||||
Path string `json:",default=logs"`
|
Path string `json:",default=logs"`
|
||||||
Level string `json:",default=info,options=info|error|severe"`
|
Level string `json:",default=info,options=info|error|severe"`
|
||||||
Compress bool `json:",optional"`
|
Compress bool `json:",optional"`
|
||||||
|
|||||||
@@ -32,8 +32,6 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
timeFormat = "2006-01-02T15:04:05.000Z07"
|
|
||||||
|
|
||||||
accessFilename = "access.log"
|
accessFilename = "access.log"
|
||||||
errorFilename = "error.log"
|
errorFilename = "error.log"
|
||||||
severeFilename = "severe.log"
|
severeFilename = "severe.log"
|
||||||
@@ -64,6 +62,7 @@ var (
|
|||||||
// ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
|
// ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
|
||||||
ErrLogServiceNameNotSet = errors.New("log service name must be set")
|
ErrLogServiceNameNotSet = errors.New("log service name must be set")
|
||||||
|
|
||||||
|
timeFormat = "2006-01-02T15:04:05.000Z07"
|
||||||
writeConsole bool
|
writeConsole bool
|
||||||
logLevel uint32
|
logLevel uint32
|
||||||
infoLog io.WriteCloser
|
infoLog io.WriteCloser
|
||||||
@@ -117,6 +116,10 @@ func MustSetup(c LogConf) {
|
|||||||
// we need to allow different service frameworks to initialize logx respectively.
|
// we need to allow different service frameworks to initialize logx respectively.
|
||||||
// the same logic for SetUp
|
// the same logic for SetUp
|
||||||
func SetUp(c LogConf) error {
|
func SetUp(c LogConf) error {
|
||||||
|
if len(c.TimeFormat) > 0 {
|
||||||
|
timeFormat = c.TimeFormat
|
||||||
|
}
|
||||||
|
|
||||||
switch c.Mode {
|
switch c.Mode {
|
||||||
case consoleMode:
|
case consoleMode:
|
||||||
setupWithConsole(c)
|
setupWithConsole(c)
|
||||||
|
|||||||
@@ -43,11 +43,11 @@ type (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func newCollection(collection *mgo.Collection) Collection {
|
func newCollection(collection *mgo.Collection, brk breaker.Breaker) Collection {
|
||||||
return &decoratedCollection{
|
return &decoratedCollection{
|
||||||
name: collection.FullName,
|
name: collection.FullName,
|
||||||
collection: collection,
|
collection: collection,
|
||||||
brk: breaker.NewBreaker(),
|
brk: brk,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func TestNewCollection(t *testing.T) {
|
|||||||
Database: nil,
|
Database: nil,
|
||||||
Name: "foo",
|
Name: "foo",
|
||||||
FullName: "bar",
|
FullName: "bar",
|
||||||
})
|
}, breaker.GetBreaker("localhost"))
|
||||||
assert.Equal(t, "bar", col.(*decoratedCollection).name)
|
assert.Equal(t, "bar", col.(*decoratedCollection).name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/globalsign/mgo"
|
"github.com/globalsign/mgo"
|
||||||
|
"github.com/tal-tech/go-zero/core/breaker"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@@ -20,6 +21,7 @@ type (
|
|||||||
session *concurrentSession
|
session *concurrentSession
|
||||||
db *mgo.Database
|
db *mgo.Database
|
||||||
collection string
|
collection string
|
||||||
|
brk breaker.Breaker
|
||||||
opts []Option
|
opts []Option
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -46,6 +48,7 @@ func NewModel(url, collection string, opts ...Option) (*Model, error) {
|
|||||||
// If name is empty, the database name provided in the dialed URL is used instead
|
// If name is empty, the database name provided in the dialed URL is used instead
|
||||||
db: session.DB(""),
|
db: session.DB(""),
|
||||||
collection: collection,
|
collection: collection,
|
||||||
|
brk: breaker.GetBreaker(url),
|
||||||
opts: opts,
|
opts: opts,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -66,7 +69,7 @@ func (mm *Model) FindId(id interface{}) (Query, error) {
|
|||||||
|
|
||||||
// GetCollection returns a Collection with given session.
|
// GetCollection returns a Collection with given session.
|
||||||
func (mm *Model) GetCollection(session *mgo.Session) Collection {
|
func (mm *Model) GetCollection(session *mgo.Session) Collection {
|
||||||
return newCollection(mm.db.C(mm.collection).With(session))
|
return newCollection(mm.db.C(mm.collection).With(session), mm.brk)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert inserts docs into mm.
|
// Insert inserts docs into mm.
|
||||||
|
|||||||
@@ -250,6 +250,21 @@ func (s *Redis) Eval(script string, keys []string, args ...interface{}) (val int
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EvalSha is the implementation of redis evalsha command.
|
||||||
|
func (s *Redis) EvalSha(sha string, keys []string, args ...interface{}) (val interface{}, err error) {
|
||||||
|
err = s.brk.DoWithAcceptable(func() error {
|
||||||
|
conn, err := getRedis(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err = conn.EvalSha(sha, keys, args...).Result()
|
||||||
|
return err
|
||||||
|
}, acceptable)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Exists is the implementation of redis exists command.
|
// Exists is the implementation of redis exists command.
|
||||||
func (s *Redis) Exists(key string) (val bool, err error) {
|
func (s *Redis) Exists(key string) (val bool, err error) {
|
||||||
err = s.brk.DoWithAcceptable(func() error {
|
err = s.brk.DoWithAcceptable(func() error {
|
||||||
@@ -449,14 +464,14 @@ func (s *Redis) GetBit(key string, offset int64) (val int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Hdel is the implementation of redis hdel command.
|
// Hdel is the implementation of redis hdel command.
|
||||||
func (s *Redis) Hdel(key, field string) (val bool, err error) {
|
func (s *Redis) Hdel(key string, fields ...string) (val bool, err error) {
|
||||||
err = s.brk.DoWithAcceptable(func() error {
|
err = s.brk.DoWithAcceptable(func() error {
|
||||||
conn, err := getRedis(s)
|
conn, err := getRedis(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
v, err := conn.HDel(key, field).Result()
|
v, err := conn.HDel(key, fields...).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1032,6 +1047,16 @@ func (s *Redis) Scard(key string) (val int64, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScriptLoad is the implementation of redis script load command.
|
||||||
|
func (s *Redis) ScriptLoad(script string) (string, error) {
|
||||||
|
conn, err := getRedis(s)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn.ScriptLoad(script).Result()
|
||||||
|
}
|
||||||
|
|
||||||
// Set is the implementation of redis set command.
|
// Set is the implementation of redis set command.
|
||||||
func (s *Redis) Set(key string, value string) error {
|
func (s *Redis) Set(key string, value string) error {
|
||||||
return s.brk.DoWithAcceptable(func() error {
|
return s.brk.DoWithAcceptable(func() error {
|
||||||
@@ -1101,26 +1126,6 @@ func (s *Redis) Sismember(key string, value interface{}) (val bool, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Srem is the implementation of redis srem command.
|
|
||||||
func (s *Redis) Srem(key string, values ...interface{}) (val int, err error) {
|
|
||||||
err = s.brk.DoWithAcceptable(func() error {
|
|
||||||
conn, err := getRedis(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
v, err := conn.SRem(key, values...).Result()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
val = int(v)
|
|
||||||
return nil
|
|
||||||
}, acceptable)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Smembers is the implementation of redis smembers command.
|
// Smembers is the implementation of redis smembers command.
|
||||||
func (s *Redis) Smembers(key string) (val []string, err error) {
|
func (s *Redis) Smembers(key string) (val []string, err error) {
|
||||||
err = s.brk.DoWithAcceptable(func() error {
|
err = s.brk.DoWithAcceptable(func() error {
|
||||||
@@ -1166,6 +1171,31 @@ func (s *Redis) Srandmember(key string, count int) (val []string, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Srem is the implementation of redis srem command.
|
||||||
|
func (s *Redis) Srem(key string, values ...interface{}) (val int, err error) {
|
||||||
|
err = s.brk.DoWithAcceptable(func() error {
|
||||||
|
conn, err := getRedis(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
v, err := conn.SRem(key, values...).Result()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
val = int(v)
|
||||||
|
return nil
|
||||||
|
}, acceptable)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of s.
|
||||||
|
func (s *Redis) String() string {
|
||||||
|
return s.Addr
|
||||||
|
}
|
||||||
|
|
||||||
// Sunion is the implementation of redis sunion command.
|
// Sunion is the implementation of redis sunion command.
|
||||||
func (s *Redis) Sunion(keys ...string) (val []string, err error) {
|
func (s *Redis) Sunion(keys ...string) (val []string, err error) {
|
||||||
err = s.brk.DoWithAcceptable(func() error {
|
err = s.brk.DoWithAcceptable(func() error {
|
||||||
@@ -1667,20 +1697,6 @@ func (s *Redis) Zunionstore(dest string, store ZStore, keys ...string) (val int6
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// String returns the string representation of s.
|
|
||||||
func (s *Redis) String() string {
|
|
||||||
return s.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Redis) scriptLoad(script string) (string, error) {
|
|
||||||
conn, err := getRedis(s)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn.ScriptLoad(script).Result()
|
|
||||||
}
|
|
||||||
|
|
||||||
func acceptable(err error) bool {
|
func acceptable(err error) bool {
|
||||||
return err == nil || err == red.Nil
|
return err == nil || err == red.Nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -947,13 +947,24 @@ func TestRedisString(t *testing.T) {
|
|||||||
func TestRedisScriptLoad(t *testing.T) {
|
func TestRedisScriptLoad(t *testing.T) {
|
||||||
runOnRedis(t, func(client *Redis) {
|
runOnRedis(t, func(client *Redis) {
|
||||||
client.Ping()
|
client.Ping()
|
||||||
_, err := NewRedis(client.Addr, "").scriptLoad("foo")
|
_, err := NewRedis(client.Addr, "").ScriptLoad("foo")
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
_, err = client.scriptLoad("foo")
|
_, err = client.ScriptLoad("foo")
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRedisEvalSha(t *testing.T) {
|
||||||
|
runOnRedis(t, func(client *Redis) {
|
||||||
|
client.Ping()
|
||||||
|
scriptHash, err := client.ScriptLoad(`return redis.call("EXISTS", KEYS[1])`)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
result, err := client.EvalSha(scriptHash, []string{"key1"})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, int64(0), result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestRedisToPairs(t *testing.T) {
|
func TestRedisToPairs(t *testing.T) {
|
||||||
pairs := toPairs([]red.Z{
|
pairs := toPairs([]red.Z{
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ type (
|
|||||||
ResultHandler func(sql.Result, error)
|
ResultHandler func(sql.Result, error)
|
||||||
|
|
||||||
// A BulkInserter is used to batch insert records.
|
// A BulkInserter is used to batch insert records.
|
||||||
|
// Postgresql is not supported yet, because of the sql is formated with symbol `$`.
|
||||||
BulkInserter struct {
|
BulkInserter struct {
|
||||||
executor *executors.PeriodicalExecutor
|
executor *executors.PeriodicalExecutor
|
||||||
inserter *dbInserter
|
inserter *dbInserter
|
||||||
|
|||||||
@@ -12,14 +12,10 @@ import (
|
|||||||
const slowThreshold = time.Millisecond * 500
|
const slowThreshold = time.Millisecond * 500
|
||||||
|
|
||||||
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||||
stmt, err := format(q, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
startTime := timex.Now()
|
startTime := timex.Now()
|
||||||
result, err := conn.Exec(q, args...)
|
result, err := conn.Exec(q, args...)
|
||||||
duration := timex.Since(startTime)
|
duration := timex.Since(startTime)
|
||||||
|
stmt := formatForPrint(q, args)
|
||||||
if duration > slowThreshold {
|
if duration > slowThreshold {
|
||||||
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
|
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
|
||||||
} else {
|
} else {
|
||||||
@@ -33,10 +29,10 @@ func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
||||||
stmt := fmt.Sprint(args...)
|
|
||||||
startTime := timex.Now()
|
startTime := timex.Now()
|
||||||
result, err := conn.Exec(args...)
|
result, err := conn.Exec(args...)
|
||||||
duration := timex.Since(startTime)
|
duration := timex.Since(startTime)
|
||||||
|
stmt := fmt.Sprint(args...)
|
||||||
if duration > slowThreshold {
|
if duration > slowThreshold {
|
||||||
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
|
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
|
||||||
} else {
|
} else {
|
||||||
@@ -50,14 +46,10 @@ func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||||
stmt, err := format(q, args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
startTime := timex.Now()
|
startTime := timex.Now()
|
||||||
rows, err := conn.Query(q, args...)
|
rows, err := conn.Query(q, args...)
|
||||||
duration := timex.Since(startTime)
|
duration := timex.Since(startTime)
|
||||||
|
stmt := fmt.Sprint(args...)
|
||||||
if duration > slowThreshold {
|
if duration > slowThreshold {
|
||||||
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
|
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ func TestStmt_exec(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
args []interface{}
|
args []interface{}
|
||||||
delay bool
|
delay bool
|
||||||
formatError bool
|
|
||||||
hasError bool
|
hasError bool
|
||||||
err error
|
err error
|
||||||
lastInsertId int64
|
lastInsertId int64
|
||||||
@@ -28,12 +27,6 @@ func TestStmt_exec(t *testing.T) {
|
|||||||
lastInsertId: 1,
|
lastInsertId: 1,
|
||||||
rowsAffected: 2,
|
rowsAffected: 2,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "wrong format",
|
|
||||||
args: []interface{}{1, 2},
|
|
||||||
formatError: true,
|
|
||||||
hasError: true,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "exec error",
|
name: "exec error",
|
||||||
args: []interface{}{1},
|
args: []interface{}{1},
|
||||||
@@ -70,18 +63,13 @@ func TestStmt_exec(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, fn := range fns {
|
for _, fn := range fns {
|
||||||
i := i
|
|
||||||
fn := fn
|
fn := fn
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
res, err := fn(test.args...)
|
res, err := fn(test.args...)
|
||||||
if i == 0 && test.formatError {
|
if test.hasError {
|
||||||
assert.NotNil(t, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !test.formatError && test.hasError {
|
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -100,23 +88,16 @@ func TestStmt_exec(t *testing.T) {
|
|||||||
|
|
||||||
func TestStmt_query(t *testing.T) {
|
func TestStmt_query(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
args []interface{}
|
args []interface{}
|
||||||
delay bool
|
delay bool
|
||||||
formatError bool
|
hasError bool
|
||||||
hasError bool
|
err error
|
||||||
err error
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "normal",
|
name: "normal",
|
||||||
args: []interface{}{1},
|
args: []interface{}{1},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "wrong format",
|
|
||||||
args: []interface{}{1, 2},
|
|
||||||
formatError: true,
|
|
||||||
hasError: true,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "query error",
|
name: "query error",
|
||||||
args: []interface{}{1},
|
args: []interface{}{1},
|
||||||
@@ -151,18 +132,13 @@ func TestStmt_query(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, fn := range fns {
|
for _, fn := range fns {
|
||||||
i := i
|
|
||||||
fn := fn
|
fn := fn
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
err := fn(test.args...)
|
err := fn(test.args...)
|
||||||
if i == 0 && test.formatError {
|
if test.hasError {
|
||||||
assert.NotNil(t, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !test.formatError && test.hasError {
|
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,6 +45,24 @@ func escape(input string) string {
|
|||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func formatForPrint(query string, args ...interface{}) string {
|
||||||
|
if len(args) == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
var vals []string
|
||||||
|
for _, arg := range args {
|
||||||
|
vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteByte('[')
|
||||||
|
b.WriteString(strings.Join(vals, ", "))
|
||||||
|
b.WriteByte(']')
|
||||||
|
|
||||||
|
return strings.Join([]string{query, b.String()}, " ")
|
||||||
|
}
|
||||||
|
|
||||||
func format(query string, args ...interface{}) (string, error) {
|
func format(query string, args ...interface{}) (string, error) {
|
||||||
numArgs := len(args)
|
numArgs := len(args)
|
||||||
if numArgs == 0 {
|
if numArgs == 0 {
|
||||||
|
|||||||
@@ -28,3 +28,31 @@ func TestDesensitize_WithoutAccount(t *testing.T) {
|
|||||||
datasource = desensitize(datasource)
|
datasource = desensitize(datasource)
|
||||||
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
|
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFormatForPrint(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
args []interface{}
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no args",
|
||||||
|
query: "select user, name from table where id=?",
|
||||||
|
expect: `select user, name from table where id=?`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one arg",
|
||||||
|
query: "select user, name from table where id=?",
|
||||||
|
args: []interface{}{"kevin"},
|
||||||
|
expect: `select user, name from table where id=? ["kevin"]`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
actual := formatForPrint(test.query, test.args...)
|
||||||
|
assert.Equal(t, test.expect, actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
|
|||||||
|
|
||||||
* API 文档
|
* API 文档
|
||||||
|
|
||||||
[https://www.yuque.com/tal-tech/go-zero](https://www.yuque.com/tal-tech/go-zero)
|
[https://zeromicro.github.io/go-zero](https://zeromicro.github.io/go-zero)
|
||||||
|
|
||||||
* awesome 系列(更多文章见『微服务实践』公众号)
|
* awesome 系列(更多文章见『微服务实践』公众号)
|
||||||
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
|
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
|
||||||
|
|||||||
@@ -210,6 +210,12 @@ go get -u github.com/tal-tech/go-zero
|
|||||||
* [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore-en.md)
|
* [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore-en.md)
|
||||||
* [Examples](https://github.com/zeromicro/zero-examples)
|
* [Examples](https://github.com/zeromicro/zero-examples)
|
||||||
|
|
||||||
## 9. Chat group
|
## 9. Important notes
|
||||||
|
|
||||||
|
* Use grpc 1.29.1, because etcd lib doesn’t support latter versions.
|
||||||
|
|
||||||
|
`google.golang.org/grpc v1.29.1`
|
||||||
|
|
||||||
|
## 10. Chat group
|
||||||
|
|
||||||
Join the chat via https://join.slack.com/t/go-zeroworkspace/shared_invite/zt-m39xssxc-kgIqERa7aVsujKNj~XuPKg
|
Join the chat via https://join.slack.com/t/go-zeroworkspace/shared_invite/zt-m39xssxc-kgIqERa7aVsujKNj~XuPKg
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
|
|
||||||
@@ -138,6 +140,16 @@ func (grw *guardedResponseWriter) Header() http.Header {
|
|||||||
return grw.writer.Header()
|
return grw.writer.Header()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hijack implements the http.Hijacker interface.
|
||||||
|
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||||
|
func (grw *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacked, ok := grw.writer.(http.Hijacker); ok {
|
||||||
|
return hijacked.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, errors.New("server doesn't support hijacking")
|
||||||
|
}
|
||||||
|
|
||||||
func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
|
func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
|
||||||
return grw.writer.Write(body)
|
return grw.writer.Write(body)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -87,6 +89,26 @@ func TestAuthHandler_NilError(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthHandler_Flush(t *testing.T) {
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
handler := newGuardedResponseWriter(resp)
|
||||||
|
handler.Flush()
|
||||||
|
assert.True(t, resp.Flushed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthHandler_Hijack(t *testing.T) {
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
writer := newGuardedResponseWriter(resp)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
|
||||||
|
writer = newGuardedResponseWriter(mockedHijackable{resp})
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
|
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
claims := make(jwt.MapClaims)
|
claims := make(jwt.MapClaims)
|
||||||
@@ -101,3 +123,11 @@ func buildToken(secretKey string, payloads map[string]interface{}, seconds int64
|
|||||||
|
|
||||||
return token.SignedString([]byte(secretKey))
|
return token.SignedString([]byte(secretKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockedHijackable struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/tal-tech/go-zero/core/codec"
|
"github.com/tal-tech/go-zero/core/codec"
|
||||||
@@ -94,6 +96,16 @@ func (w *cryptionResponseWriter) Header() http.Header {
|
|||||||
return w.ResponseWriter.Header()
|
return w.ResponseWriter.Header()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hijack implements the http.Hijacker interface.
|
||||||
|
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||||
|
func (w *cryptionResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacked, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||||
|
return hijacked.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, errors.New("server doesn't support hijacking")
|
||||||
|
}
|
||||||
|
|
||||||
func (w *cryptionResponseWriter) Write(p []byte) (int, error) {
|
func (w *cryptionResponseWriter) Write(p []byte) (int, error) {
|
||||||
return w.buf.Write(p)
|
return w.buf.Write(p)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -103,3 +103,16 @@ func TestCryptionHandlerFlush(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
|
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCryptionHandler_Hijack(t *testing.T) {
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
writer := newCryptionResponseWriter(resp)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
|
||||||
|
writer = newCryptionResponseWriter(mockedHijackable{resp})
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"time"
|
"time"
|
||||||
@@ -25,10 +28,26 @@ type loggedResponseWriter struct {
|
|||||||
code int
|
code int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *loggedResponseWriter) Flush() {
|
||||||
|
if flusher, ok := w.w.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (w *loggedResponseWriter) Header() http.Header {
|
func (w *loggedResponseWriter) Header() http.Header {
|
||||||
return w.w.Header()
|
return w.w.Header()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hijack implements the http.Hijacker interface.
|
||||||
|
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||||
|
func (w *loggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacked, ok := w.w.(http.Hijacker); ok {
|
||||||
|
return hijacked.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, errors.New("server doesn't support hijacking")
|
||||||
|
}
|
||||||
|
|
||||||
func (w *loggedResponseWriter) Write(bytes []byte) (int, error) {
|
func (w *loggedResponseWriter) Write(bytes []byte) (int, error) {
|
||||||
return w.w.Write(bytes)
|
return w.w.Write(bytes)
|
||||||
}
|
}
|
||||||
@@ -38,12 +57,6 @@ func (w *loggedResponseWriter) WriteHeader(code int) {
|
|||||||
w.code = code
|
w.code = code
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *loggedResponseWriter) Flush() {
|
|
||||||
if flusher, ok := w.w.(http.Flusher); ok {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// LogHandler returns a middleware that logs http request and response.
|
// LogHandler returns a middleware that logs http request and response.
|
||||||
func LogHandler(next http.Handler) http.Handler {
|
func LogHandler(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -83,6 +96,16 @@ func (w *detailLoggedResponseWriter) Header() http.Header {
|
|||||||
return w.writer.Header()
|
return w.writer.Header()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hijack implements the http.Hijacker interface.
|
||||||
|
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||||
|
func (w *detailLoggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacked, ok := w.writer.w.(http.Hijacker); ok {
|
||||||
|
return hijacked.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, errors.New("server doesn't support hijacking")
|
||||||
|
}
|
||||||
|
|
||||||
func (w *detailLoggedResponseWriter) Write(bs []byte) (int, error) {
|
func (w *detailLoggedResponseWriter) Write(bs []byte) (int, error) {
|
||||||
w.buf.Write(bs)
|
w.buf.Write(bs)
|
||||||
return w.writer.Write(bs)
|
return w.writer.Write(bs)
|
||||||
|
|||||||
@@ -62,6 +62,44 @@ func TestLogHandlerSlow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLogHandler_Hijack(t *testing.T) {
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
writer := &loggedResponseWriter{
|
||||||
|
w: resp,
|
||||||
|
}
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
|
||||||
|
writer = &loggedResponseWriter{
|
||||||
|
w: mockedHijackable{resp},
|
||||||
|
}
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetailedLogHandler_Hijack(t *testing.T) {
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
writer := &detailLoggedResponseWriter{
|
||||||
|
writer: &loggedResponseWriter{
|
||||||
|
w: resp,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
|
||||||
|
writer = &detailLoggedResponseWriter{
|
||||||
|
writer: &loggedResponseWriter{
|
||||||
|
w: mockedHijackable{resp},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkLogHandler(b *testing.B) {
|
func BenchmarkLogHandler(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
package security
|
package security
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"bufio"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
// A WithCodeResponseWriter is a helper to delay sealing a http.ResponseWriter on writing code.
|
// A WithCodeResponseWriter is a helper to delay sealing a http.ResponseWriter on writing code.
|
||||||
type WithCodeResponseWriter struct {
|
type WithCodeResponseWriter struct {
|
||||||
@@ -20,6 +24,12 @@ func (w *WithCodeResponseWriter) Header() http.Header {
|
|||||||
return w.Writer.Header()
|
return w.Writer.Header()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hijack implements the http.Hijacker interface.
|
||||||
|
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||||
|
func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
return w.Writer.(http.Hijacker).Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
// Write writes bytes into w.
|
// Write writes bytes into w.
|
||||||
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
|
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
|
||||||
return w.Writer.Write(bytes)
|
return w.Writer.Write(bytes)
|
||||||
|
|||||||
@@ -77,6 +77,8 @@ func (e *Server) AddRoute(r Route, opts ...RouteOption) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the Server.
|
// Start starts the Server.
|
||||||
|
// Graceful shutdown is enabled by default.
|
||||||
|
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
|
||||||
func (e *Server) Start() {
|
func (e *Server) Start() {
|
||||||
handleError(e.opts.start(e.ngin))
|
handleError(e.opts.start(e.ngin))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -103,17 +103,9 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error {
|
func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error {
|
||||||
defineStruct, ok := ty.(spec.DefineStruct)
|
defineStruct, done, err := c.checkStruct(ty)
|
||||||
if !ok {
|
if done {
|
||||||
return errors.New("unsupported type %s" + ty.Name())
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
for _, item := range c.requestTypes {
|
|
||||||
if item.Name() == defineStruct.Name() {
|
|
||||||
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
modelFile := util.Title(ty.Name()) + ".java"
|
modelFile := util.Title(ty.Name()) + ".java"
|
||||||
@@ -181,6 +173,22 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *componentsContext) checkStruct(ty spec.Type) (spec.DefineStruct, bool, error) {
|
||||||
|
defineStruct, ok := ty.(spec.DefineStruct)
|
||||||
|
if !ok {
|
||||||
|
return spec.DefineStruct{}, true, errors.New("unsupported type %s" + ty.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range c.requestTypes {
|
||||||
|
if item.Name() == defineStruct.Name() {
|
||||||
|
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
|
||||||
|
return spec.DefineStruct{}, true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defineStruct, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) {
|
func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
if err := c.writeType(&builder, defineStruct); err != nil {
|
if err := c.writeType(&builder, defineStruct); err != nil {
|
||||||
|
|||||||
@@ -95,17 +95,9 @@ func specTypeToJava(tp spec.Type) (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch valueType {
|
s := getBaseType(valueType)
|
||||||
case "int":
|
if len(s) == 0 {
|
||||||
return "Integer[]", nil
|
return s, errors.New("unsupported primitive type " + tp.Name())
|
||||||
case "long":
|
|
||||||
return "Long[]", nil
|
|
||||||
case "float":
|
|
||||||
return "Float[]", nil
|
|
||||||
case "double":
|
|
||||||
return "Double[]", nil
|
|
||||||
case "boolean":
|
|
||||||
return "Boolean[]", nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil
|
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil
|
||||||
@@ -118,6 +110,23 @@ func specTypeToJava(tp spec.Type) (string, error) {
|
|||||||
return "", errors.New("unsupported primitive type " + tp.Name())
|
return "", errors.New("unsupported primitive type " + tp.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getBaseType(valueType string) string {
|
||||||
|
switch valueType {
|
||||||
|
case "int":
|
||||||
|
return "Integer[]"
|
||||||
|
case "long":
|
||||||
|
return "Long[]"
|
||||||
|
case "float":
|
||||||
|
return "Float[]"
|
||||||
|
case "double":
|
||||||
|
return "Double[]"
|
||||||
|
case "boolean":
|
||||||
|
return "Boolean[]"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func primitiveType(tp string) (string, bool) {
|
func primitiveType(tp string) (string, bool) {
|
||||||
switch tp {
|
switch tp {
|
||||||
case "string":
|
case "string":
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
|
|||||||
for _, each := range ctx.AllSpec() {
|
for _, each := range ctx.AllSpec() {
|
||||||
root := each.Accept(v).(*Api)
|
root := each.Accept(v).(*Api)
|
||||||
v.acceptSyntax(root, &final)
|
v.acceptSyntax(root, &final)
|
||||||
v.accetpImport(root, &final)
|
v.acceptImport(root, &final)
|
||||||
v.acceptInfo(root, &final)
|
v.acceptInfo(root, &final)
|
||||||
v.acceptType(root, &final)
|
v.acceptType(root, &final)
|
||||||
v.acceptService(root, &final)
|
v.acceptService(root, &final)
|
||||||
@@ -133,7 +133,7 @@ func (v *ApiVisitor) acceptInfo(root *Api, final *Api) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *ApiVisitor) accetpImport(root *Api, final *Api) {
|
func (v *ApiVisitor) acceptImport(root *Api, final *Api) {
|
||||||
for _, imp := range root.Import {
|
for _, imp := range root.Import {
|
||||||
if _, ok := final.importM[imp.Value.Text()]; ok {
|
if _, ok := final.importM[imp.Value.Text()]; ok {
|
||||||
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))
|
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ type AtDoc struct {
|
|||||||
Kv []*KvExpr
|
Kv []*KvExpr
|
||||||
}
|
}
|
||||||
|
|
||||||
// AtHandler describes service hander ast for api syntax
|
// AtHandler describes service handler ast for api syntax
|
||||||
type AtHandler struct {
|
type AtHandler struct {
|
||||||
AtHandlerToken Expr
|
AtHandlerToken Expr
|
||||||
Name Expr
|
Name Expr
|
||||||
@@ -630,7 +630,7 @@ func (s *Service) Equal(v interface{}) bool {
|
|||||||
return s.ServiceApi.Equal(service.ServiceApi)
|
return s.ServiceApi.Equal(service.ServiceApi)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get returns the tergate KV by specified key
|
// Get returns the target KV by specified key
|
||||||
func (kv KV) Get(key string) Expr {
|
func (kv KV) Get(key string) Expr {
|
||||||
for _, each := range kv {
|
for _, each := range kv {
|
||||||
if each.Key.Text() == key {
|
if each.Key.Text() == key {
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ type (
|
|||||||
NameExpr() Expr
|
NameExpr() Expr
|
||||||
}
|
}
|
||||||
|
|
||||||
// TypeAlias describes alias ast for api syatax
|
// TypeAlias describes alias ast for api syntax
|
||||||
TypeAlias struct {
|
TypeAlias struct {
|
||||||
Name Expr
|
Name Expr
|
||||||
Assign Expr
|
Assign Expr
|
||||||
@@ -26,7 +26,7 @@ type (
|
|||||||
CommentExpr Expr
|
CommentExpr Expr
|
||||||
}
|
}
|
||||||
|
|
||||||
// TypeStruct describes structure ast for api syatax
|
// TypeStruct describes structure ast for api syntax
|
||||||
TypeStruct struct {
|
TypeStruct struct {
|
||||||
Name Expr
|
Name Expr
|
||||||
Struct Expr
|
Struct Expr
|
||||||
@@ -225,7 +225,7 @@ func (v *ApiVisitor) VisitTypeBlockAlias(ctx *api.TypeBlockAliasContext) interfa
|
|||||||
alias.DocExpr = v.getDoc(ctx)
|
alias.DocExpr = v.getDoc(ctx)
|
||||||
alias.CommentExpr = v.getComment(ctx)
|
alias.CommentExpr = v.getComment(ctx)
|
||||||
// todo: reopen if necessary
|
// todo: reopen if necessary
|
||||||
v.panic(alias.Name, "unsupport alias")
|
v.panic(alias.Name, "unsupported alias")
|
||||||
return &alias
|
return &alias
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -238,7 +238,7 @@ func (v *ApiVisitor) VisitTypeAlias(ctx *api.TypeAliasContext) interface{} {
|
|||||||
alias.DocExpr = v.getDoc(ctx)
|
alias.DocExpr = v.getDoc(ctx)
|
||||||
alias.CommentExpr = v.getComment(ctx)
|
alias.CommentExpr = v.getComment(ctx)
|
||||||
// todo: reopen if necessary
|
// todo: reopen if necessary
|
||||||
v.panic(alias.Name, "unsupport alias")
|
v.panic(alias.Name, "unsupported alias")
|
||||||
return &alias
|
return &alias
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -319,7 +319,7 @@ func (v *ApiVisitor) VisitDataType(ctx *api.DataTypeContext) interface{} {
|
|||||||
if ctx.GetTime() != nil {
|
if ctx.GetTime() != nil {
|
||||||
// todo: reopen if it is necessary
|
// todo: reopen if it is necessary
|
||||||
timeExpr := v.newExprWithToken(ctx.GetTime())
|
timeExpr := v.newExprWithToken(ctx.GetTime())
|
||||||
v.panic(timeExpr, "unsupport time.Time")
|
v.panic(timeExpr, "unsupported time.Time")
|
||||||
return &Time{Literal: timeExpr}
|
return &Time{Literal: timeExpr}
|
||||||
}
|
}
|
||||||
if ctx.PointerType() != nil {
|
if ctx.PointerType() != nil {
|
||||||
|
|||||||
@@ -219,9 +219,9 @@ func (p parser) fillService() error {
|
|||||||
|
|
||||||
for _, astRoute := range item.ServiceApi.ServiceRoute {
|
for _, astRoute := range item.ServiceApi.ServiceRoute {
|
||||||
route := spec.Route{
|
route := spec.Route{
|
||||||
Annotation: spec.Annotation{},
|
AtServerAnnotation: spec.Annotation{},
|
||||||
Method: astRoute.Route.Method.Text(),
|
Method: astRoute.Route.Method.Text(),
|
||||||
Path: astRoute.Route.Path.Text(),
|
Path: astRoute.Route.Path.Text(),
|
||||||
}
|
}
|
||||||
if astRoute.AtHandler != nil {
|
if astRoute.AtHandler != nil {
|
||||||
route.Handler = astRoute.AtHandler.Name.Text()
|
route.Handler = astRoute.AtHandler.Name.Text()
|
||||||
@@ -275,7 +275,7 @@ func (p parser) fillRouteAtServer(astRoute *ast.ServiceRoute, route *spec.Route)
|
|||||||
for _, kv := range astRoute.AtServer.Kv {
|
for _, kv := range astRoute.AtServer.Kv {
|
||||||
properties[kv.Key.Text()] = kv.Value.Text()
|
properties[kv.Key.Text()] = kv.Value.Text()
|
||||||
}
|
}
|
||||||
route.Annotation.Properties = properties
|
route.AtServerAnnotation.Properties = properties
|
||||||
if len(route.Handler) == 0 {
|
if len(route.Handler) == 0 {
|
||||||
route.Handler = properties["handler"]
|
route.Handler = properties["handler"]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,10 +11,11 @@ import (
|
|||||||
const (
|
const (
|
||||||
bodyTagKey = "json"
|
bodyTagKey = "json"
|
||||||
formTagKey = "form"
|
formTagKey = "form"
|
||||||
|
pathTagKey = "path"
|
||||||
defaultSummaryKey = "summary"
|
defaultSummaryKey = "summary"
|
||||||
)
|
)
|
||||||
|
|
||||||
var definedKeys = []string{bodyTagKey, formTagKey, "path"}
|
var definedKeys = []string{bodyTagKey, formTagKey, pathTagKey}
|
||||||
|
|
||||||
// Routes returns all routes in api service
|
// Routes returns all routes in api service
|
||||||
func (s Service) Routes() []Route {
|
func (s Service) Routes() []Route {
|
||||||
@@ -25,7 +26,7 @@ func (s Service) Routes() []Route {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tags retuens all tags in Member
|
// Tags returns all tags in Member
|
||||||
func (m Member) Tags() []*Tag {
|
func (m Member) Tags() []*Tag {
|
||||||
tags, err := Parse(m.Tag)
|
tags, err := Parse(m.Tag)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -141,7 +142,7 @@ func (t DefineStruct) GetFormMembers() []Member {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNonBodyMembers retruns all have no tag fields
|
// GetNonBodyMembers returns all have no tag fields
|
||||||
func (t DefineStruct) GetNonBodyMembers() []Member {
|
func (t DefineStruct) GetNonBodyMembers() []Member {
|
||||||
var result []Member
|
var result []Member
|
||||||
for _, member := range t.Members {
|
for _, member := range t.Members {
|
||||||
@@ -162,16 +163,16 @@ func (r Route) JoinedDoc() string {
|
|||||||
return strings.TrimSpace(doc)
|
return strings.TrimSpace(doc)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAnnotation returns the value by specified key
|
// GetAnnotation returns the value by specified key from @server
|
||||||
func (r Route) GetAnnotation(key string) string {
|
func (r Route) GetAnnotation(key string) string {
|
||||||
if r.Annotation.Properties == nil {
|
if r.AtServerAnnotation.Properties == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.Annotation.Properties[key]
|
return r.AtServerAnnotation.Properties[key]
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAnnotation returns the value by specified key
|
// GetAnnotation returns the value by specified key from @server
|
||||||
func (g Group) GetAnnotation(key string) string {
|
func (g Group) GetAnnotation(key string) string {
|
||||||
if g.Annotation.Properties == nil {
|
if g.Annotation.Properties == nil {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -63,14 +63,14 @@ type (
|
|||||||
|
|
||||||
// Route describes api route
|
// Route describes api route
|
||||||
Route struct {
|
Route struct {
|
||||||
Annotation Annotation
|
AtServerAnnotation Annotation
|
||||||
Method string
|
Method string
|
||||||
Path string
|
Path string
|
||||||
RequestType Type
|
RequestType Type
|
||||||
ResponseType Type
|
ResponseType Type
|
||||||
Docs Doc
|
Docs Doc
|
||||||
Handler string
|
Handler string
|
||||||
AtDoc AtDoc
|
AtDoc AtDoc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Service describes api service
|
// Service describes api service
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"github.com/urfave/cli"
|
"github.com/urfave/cli"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TsCommand provides the entry to generting typescript codes
|
// TsCommand provides the entry to generate typescript codes
|
||||||
func TsCommand(c *cli.Context) error {
|
func TsCommand(c *cli.Context) error {
|
||||||
apiFile := c.String("api")
|
apiFile := c.String("api")
|
||||||
dir := c.String("dir")
|
dir := c.String("dir")
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/tal-tech/go-zero/tools/goctl/configgen"
|
"github.com/tal-tech/go-zero/tools/goctl/configgen"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/docker"
|
"github.com/tal-tech/go-zero/tools/goctl/docker"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/kube"
|
"github.com/tal-tech/go-zero/tools/goctl/kube"
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/model/mongo"
|
||||||
model "github.com/tal-tech/go-zero/tools/goctl/model/sql/command"
|
model "github.com/tal-tech/go-zero/tools/goctl/model/sql/command"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/plugin"
|
"github.com/tal-tech/go-zero/tools/goctl/plugin"
|
||||||
rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/cli"
|
rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/cli"
|
||||||
@@ -28,7 +29,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
buildVersion = "1.1.5"
|
buildVersion = "1.1.6"
|
||||||
commands = []cli.Command{
|
commands = []cli.Command{
|
||||||
{
|
{
|
||||||
Name: "upgrade",
|
Name: "upgrade",
|
||||||
@@ -447,6 +448,29 @@ var (
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "mongo",
|
||||||
|
Usage: `generate mongo model`,
|
||||||
|
Flags: []cli.Flag{
|
||||||
|
cli.StringSliceFlag{
|
||||||
|
Name: "type, t",
|
||||||
|
Usage: "specified model type name",
|
||||||
|
},
|
||||||
|
cli.BoolFlag{
|
||||||
|
Name: "cache, c",
|
||||||
|
Usage: "generate code with cache [optional]",
|
||||||
|
},
|
||||||
|
cli.StringFlag{
|
||||||
|
Name: "dir, d",
|
||||||
|
Usage: "the target dir",
|
||||||
|
},
|
||||||
|
cli.StringFlag{
|
||||||
|
Name: "style",
|
||||||
|
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Action: mongo.Action,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
69
tools/goctl/model/mongo/generate/generate.go
Normal file
69
tools/goctl/model/mongo/generate/generate.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package generate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/model/mongo/template"
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/util/format"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Context defines the model generation data what they needs
|
||||||
|
type Context struct {
|
||||||
|
Types []string
|
||||||
|
Cache bool
|
||||||
|
Output string
|
||||||
|
Cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do executes model template and output the result into the specified file path
|
||||||
|
func Do(ctx *Context) error {
|
||||||
|
if ctx.Cfg == nil {
|
||||||
|
return errors.New("missing config")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := generateModel(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return generateError(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateModel(ctx *Context) error {
|
||||||
|
for _, t := range ctx.Types {
|
||||||
|
fn, err := format.FileNamingFormat(ctx.Cfg.NamingFormat, t+"_model")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
text, err := util.LoadTemplate(category, modelTemplateFile, template.Text)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
output := filepath.Join(ctx.Output, fn+".go")
|
||||||
|
err = util.With("model").Parse(text).GoFmt(true).SaveTo(map[string]interface{}{
|
||||||
|
"Type": t,
|
||||||
|
"Cache": ctx.Cache,
|
||||||
|
}, output, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateError(ctx *Context) error {
|
||||||
|
text, err := util.LoadTemplate(category, errTemplateFile, template.Error)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
output := filepath.Join(ctx.Output, "error.go")
|
||||||
|
|
||||||
|
return util.With("error").Parse(text).GoFmt(true).SaveTo(ctx, output, false)
|
||||||
|
}
|
||||||
34
tools/goctl/model/mongo/generate/generate_test.go
Normal file
34
tools/goctl/model/mongo/generate/generate_test.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package generate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var testTypes = `
|
||||||
|
type User struct{}
|
||||||
|
type Class struct{}
|
||||||
|
`
|
||||||
|
|
||||||
|
func TestDo(t *testing.T) {
|
||||||
|
cfg, err := config.NewConfig(config.DefaultFormat)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
typesfile := filepath.Join(tempDir, "types.go")
|
||||||
|
err = ioutil.WriteFile(typesfile, []byte(testTypes), 0666)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = Do(&Context{
|
||||||
|
Types: []string{"User", "Class"},
|
||||||
|
Cache: false,
|
||||||
|
Output: tempDir,
|
||||||
|
Cfg: cfg,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
55
tools/goctl/model/mongo/generate/template.go
Normal file
55
tools/goctl/model/mongo/generate/template.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package generate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/model/mongo/template"
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||||
|
"github.com/urfave/cli"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
category = "mongo"
|
||||||
|
modelTemplateFile = "model.tpl"
|
||||||
|
errTemplateFile = "err.tpl"
|
||||||
|
)
|
||||||
|
|
||||||
|
var templates = map[string]string{
|
||||||
|
modelTemplateFile: template.Text,
|
||||||
|
errTemplateFile: template.Error,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Category returns the mongo category.
|
||||||
|
func Category() string {
|
||||||
|
return category
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean cleans the mongo templates.
|
||||||
|
func Clean() error {
|
||||||
|
return util.Clean(category)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Templates initializes the mongo templates.
|
||||||
|
func Templates(_ *cli.Context) error {
|
||||||
|
return util.InitTemplates(category, templates)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevertTemplate reverts the given template.
|
||||||
|
func RevertTemplate(name string) error {
|
||||||
|
content, ok := templates[name]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("%s: no such file name", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.CreateTemplate(category, name, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update cleans and updates the templates.
|
||||||
|
func Update() error {
|
||||||
|
err := Clean()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.InitTemplates(category, templates)
|
||||||
|
}
|
||||||
39
tools/goctl/model/mongo/mongo.go
Normal file
39
tools/goctl/model/mongo/mongo.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package mongo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/model/mongo/generate"
|
||||||
|
"github.com/urfave/cli"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Action provides the entry for goctl mongo code generation.
|
||||||
|
func Action(ctx *cli.Context) error {
|
||||||
|
tp := ctx.StringSlice("type")
|
||||||
|
c := ctx.Bool("cache")
|
||||||
|
o := strings.TrimSpace(ctx.String("dir"))
|
||||||
|
s := ctx.String("style")
|
||||||
|
if len(tp) == 0 {
|
||||||
|
return errors.New("missing type")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := config.NewConfig(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
a, err := filepath.Abs(o)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return generate.Do(&generate.Context{
|
||||||
|
Types: tp,
|
||||||
|
Cache: c,
|
||||||
|
Output: a,
|
||||||
|
Cfg: cfg,
|
||||||
|
})
|
||||||
|
}
|
||||||
210
tools/goctl/model/mongo/readme.md
Normal file
210
tools/goctl/model/mongo/readme.md
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
# mongo生成model
|
||||||
|
|
||||||
|
## 背景
|
||||||
|
|
||||||
|
在业务务开发中,model(dao)数据访问层是一个服务必不可缺的一层,因此数据库访问的CURD也是必须要对外提供的访问方法, 而CURD在go-zero中就仅存在两种情况
|
||||||
|
|
||||||
|
* 带缓存model
|
||||||
|
* 不带缓存model
|
||||||
|
|
||||||
|
从代码结构上来看,C-U-R-D四个方法就是固定的结构,因此我们可以将其交给goctl工具去完成,帮助我们提升开发效率。
|
||||||
|
|
||||||
|
## 方案设计
|
||||||
|
|
||||||
|
mongo的生成不同于mysql,mysql可以从scheme_information库中读取到一张表的信息(字段名称,数据类型,索引等),
|
||||||
|
而mongo是文档型数据库,我们暂时无法从db中读取某一条记录来实现字段信息获取,就算有也不一定是完整信息(某些字段可能是omitempty修饰,可有可无), 这里采用type自己编写+代码生成方式实现
|
||||||
|
|
||||||
|
## 使用示例
|
||||||
|
|
||||||
|
假设我们需要生成一个usermodel.go的代码文件,其包含用户信息字段有
|
||||||
|
|
||||||
|
|字段名称|字段类型|
|
||||||
|
|---|---|
|
||||||
|
|_id|bson.ObejctId|
|
||||||
|
|name|string|
|
||||||
|
|
||||||
|
### 编写types.go
|
||||||
|
|
||||||
|
```shell
|
||||||
|
$ vim types.go
|
||||||
|
```
|
||||||
|
|
||||||
|
```golang
|
||||||
|
package model
|
||||||
|
|
||||||
|
//go:generate goctl model mongo -t User
|
||||||
|
import "github.com/globalsign/mgo/bson"
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID bson.ObjectId `bson:"_id"`
|
||||||
|
Name string `bson:"name"`
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 生成代码
|
||||||
|
|
||||||
|
生成代码的方式有两种
|
||||||
|
|
||||||
|
* 命令行生成 在types.go所在文件夹执行命令
|
||||||
|
```shell
|
||||||
|
$ goctl model mongo -t User -style gozero
|
||||||
|
```
|
||||||
|
* 在types.go中添加`//go:generate`,然后点击执行按钮即可生成,内容示例如下:
|
||||||
|
```golang
|
||||||
|
//go:generate goctl model mongo -t User
|
||||||
|
```
|
||||||
|
|
||||||
|
### 生成示例代码
|
||||||
|
|
||||||
|
* usermodel.go
|
||||||
|
|
||||||
|
```golang
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/globalsign/mgo/bson"
|
||||||
|
cachec "github.com/tal-tech/go-zero/core/stores/cache"
|
||||||
|
"github.com/tal-tech/go-zero/core/stores/mongoc"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserModel interface {
|
||||||
|
Insert(data *User, ctx context.Context) error
|
||||||
|
FindOne(id string, ctx context.Context) (*User, error)
|
||||||
|
Update(data *User, ctx context.Context) error
|
||||||
|
Delete(id string, ctx context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type defaultUserModel struct {
|
||||||
|
*mongoc.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserModel(url, collection string, c cachec.CacheConf) UserModel {
|
||||||
|
return &defaultUserModel{
|
||||||
|
Model: mongoc.MustNewModel(url, collection, c),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultUserModel) Insert(data *User, ctx context.Context) error {
|
||||||
|
if !data.ID.Valid() {
|
||||||
|
data.ID = bson.NewObjectId()
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := m.TakeSession()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer m.PutSession(session)
|
||||||
|
return m.GetCollection(session).Insert(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultUserModel) FindOne(id string, ctx context.Context) (*User, error) {
|
||||||
|
if !bson.IsObjectIdHex(id) {
|
||||||
|
return nil, ErrInvalidObjectId
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := m.TakeSession()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer m.PutSession(session)
|
||||||
|
var data User
|
||||||
|
|
||||||
|
err = m.GetCollection(session).FindOneIdNoCache(&data, bson.ObjectIdHex(id))
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
return &data, nil
|
||||||
|
case mongoc.ErrNotFound:
|
||||||
|
return nil, ErrNotFound
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultUserModel) Update(data *User, ctx context.Context) error {
|
||||||
|
session, err := m.TakeSession()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer m.PutSession(session)
|
||||||
|
|
||||||
|
return m.GetCollection(session).UpdateIdNoCache(data.ID, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultUserModel) Delete(id string, ctx context.Context) error {
|
||||||
|
session, err := m.TakeSession()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer m.PutSession(session)
|
||||||
|
|
||||||
|
return m.GetCollection(session).RemoveIdNoCache(bson.ObjectIdHex(id))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
* error.go
|
||||||
|
|
||||||
|
```golang
|
||||||
|
package model
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var ErrNotFound = errors.New("not found")
|
||||||
|
var ErrInvalidObjectId = errors.New("invalid objectId")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 文件目录预览
|
||||||
|
|
||||||
|
```text
|
||||||
|
.
|
||||||
|
├── error.go
|
||||||
|
├── types.go
|
||||||
|
└── usermodel.go
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## 命令预览
|
||||||
|
|
||||||
|
```text
|
||||||
|
NAME:
|
||||||
|
goctl model - generate model code
|
||||||
|
|
||||||
|
USAGE:
|
||||||
|
goctl model command [command options] [arguments...]
|
||||||
|
|
||||||
|
COMMANDS:
|
||||||
|
mysql generate mysql model
|
||||||
|
mongo generate mongo model
|
||||||
|
|
||||||
|
OPTIONS:
|
||||||
|
--help, -h show help
|
||||||
|
```
|
||||||
|
|
||||||
|
```text
|
||||||
|
NAME:
|
||||||
|
goctl model mongo - generate mongo model
|
||||||
|
|
||||||
|
USAGE:
|
||||||
|
goctl model mongo [command options] [arguments...]
|
||||||
|
|
||||||
|
OPTIONS:
|
||||||
|
--type value, -t value specified model type name
|
||||||
|
--cache, -c generate code with cache [optional]
|
||||||
|
--dir value, -d value the target dir
|
||||||
|
--style value the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
> 温馨提示
|
||||||
|
>
|
||||||
|
> `--type` 支持slice传值,示例 `goctl model mongo -t=User -t=Class`
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
types.go本质上与xxxmodel.go无关,只是将type定义部分交给开发人员自己编写了,在xxxmodel.go中,mongo文档的存储结构必须包含
|
||||||
|
`_id`字段,对应到types中的field为`ID`,model中的findOne,update均以data.ID来进行操作的,当然,如果不符合你的命名风格,你也 可以修改模板,只要保证`id`
|
||||||
|
在types中的field名称和模板中一致就行。
|
||||||
112
tools/goctl/model/mongo/template/template.go
Normal file
112
tools/goctl/model/mongo/template/template.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package template
|
||||||
|
|
||||||
|
// Text provides the default template for model to generate
|
||||||
|
var Text = `package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/globalsign/mgo/bson"
|
||||||
|
cachec "github.com/tal-tech/go-zero/core/stores/cache"
|
||||||
|
"github.com/tal-tech/go-zero/core/stores/mongoc"
|
||||||
|
)
|
||||||
|
|
||||||
|
{{if .Cache}}var prefix{{.Type}}CacheKey = "cache#{{.Type}}#"{{end}}
|
||||||
|
|
||||||
|
type {{.Type}}Model interface{
|
||||||
|
Insert(ctx context.Context,data *{{.Type}}) error
|
||||||
|
FindOne(ctx context.Context,id string) (*{{.Type}}, error)
|
||||||
|
Update(ctx context.Context,data *{{.Type}}) error
|
||||||
|
Delete(ctx context.Context,id string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type default{{.Type}}Model struct {
|
||||||
|
*mongoc.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
func New{{.Type}}Model(url, collection string, c cachec.CacheConf) {{.Type}}Model {
|
||||||
|
return &default{{.Type}}Model{
|
||||||
|
Model: mongoc.MustNewModel(url, collection, c),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func (m *default{{.Type}}Model) Insert(ctx context.Context, data *{{.Type}}) error {
|
||||||
|
if !data.ID.Valid() {
|
||||||
|
data.ID = bson.NewObjectId()
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := m.TakeSession()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer m.PutSession(session)
|
||||||
|
return m.GetCollection(session).Insert(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *default{{.Type}}Model) FindOne(ctx context.Context, id string) (*{{.Type}}, error) {
|
||||||
|
if !bson.IsObjectIdHex(id) {
|
||||||
|
return nil, ErrInvalidObjectId
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := m.TakeSession()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer m.PutSession(session)
|
||||||
|
var data {{.Type}}
|
||||||
|
{{if .Cache}}key := prefix{{.Type}}CacheKey + id
|
||||||
|
err = m.GetCollection(session).FindOneId(&data, key, bson.ObjectIdHex(id))
|
||||||
|
{{- else}}
|
||||||
|
err = m.GetCollection(session).FindOneIdNoCache(&data, bson.ObjectIdHex(id))
|
||||||
|
{{- end}}
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
return &data,nil
|
||||||
|
case mongoc.ErrNotFound:
|
||||||
|
return nil,ErrNotFound
|
||||||
|
default:
|
||||||
|
return nil,err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *default{{.Type}}Model) Update(ctx context.Context, data *{{.Type}}) error {
|
||||||
|
session, err := m.TakeSession()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer m.PutSession(session)
|
||||||
|
{{if .Cache}}key := prefix{{.Type}}CacheKey + data.ID.Hex()
|
||||||
|
return m.GetCollection(session).UpdateId(data.ID, data, key)
|
||||||
|
{{- else}}
|
||||||
|
return m.GetCollection(session).UpdateIdNoCache(data.ID, data)
|
||||||
|
{{- end}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *default{{.Type}}Model) Delete(ctx context.Context, id string) error {
|
||||||
|
session, err := m.TakeSession()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer m.PutSession(session)
|
||||||
|
{{if .Cache}}key := prefix{{.Type}}CacheKey + id
|
||||||
|
return m.GetCollection(session).RemoveId(bson.ObjectIdHex(id), key)
|
||||||
|
{{- else}}
|
||||||
|
return m.GetCollection(session).RemoveIdNoCache(bson.ObjectIdHex(id))
|
||||||
|
{{- end}}
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
// Error provides the default template for error definition in mongo code generation.
|
||||||
|
var Error = `
|
||||||
|
package model
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var ErrNotFound = errors.New("not found")
|
||||||
|
var ErrInvalidObjectId = errors.New("invalid objectId")
|
||||||
|
`
|
||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/config"
|
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||||
@@ -19,7 +21,10 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestFromDDl(t *testing.T) {
|
func TestFromDDl(t *testing.T) {
|
||||||
err := fromDDl("./user.sql", t.TempDir(), cfg, true, false)
|
err := gen.Clean()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = fromDDl("./user.sql", t.TempDir(), cfg, true, false)
|
||||||
assert.Equal(t, errNotMatched, err)
|
assert.Equal(t, errNotMatched, err)
|
||||||
|
|
||||||
// case dir is not exists
|
// case dir is not exists
|
||||||
|
|||||||
@@ -25,27 +25,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
|
|||||||
var list []string
|
var list []string
|
||||||
camelTableName := table.Name.ToCamel()
|
camelTableName := table.Name.ToCamel()
|
||||||
for _, key := range table.UniqueCacheKey {
|
for _, key := range table.UniqueCacheKey {
|
||||||
var inJoin, paramJoin, argJoin Join
|
in, paramJoinString, originalFieldString := convertJoin(key)
|
||||||
for _, f := range key.Fields {
|
|
||||||
param := stringx.From(f.Name.ToCamel()).Untitle()
|
|
||||||
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
|
|
||||||
paramJoin = append(paramJoin, param)
|
|
||||||
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
|
|
||||||
}
|
|
||||||
var in string
|
|
||||||
if len(inJoin) > 0 {
|
|
||||||
in = inJoin.With(", ").Source()
|
|
||||||
}
|
|
||||||
|
|
||||||
var paramJoinString string
|
|
||||||
if len(paramJoin) > 0 {
|
|
||||||
paramJoinString = paramJoin.With(",").Source()
|
|
||||||
}
|
|
||||||
|
|
||||||
var originalFieldString string
|
|
||||||
if len(argJoin) > 0 {
|
|
||||||
originalFieldString = argJoin.With(" and ").Source()
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := t.Execute(map[string]interface{}{
|
output, err := t.Execute(map[string]interface{}{
|
||||||
"upperStartCamelObject": camelTableName,
|
"upperStartCamelObject": camelTableName,
|
||||||
@@ -125,3 +105,25 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
|
|||||||
findOneInterfaceMethod: strings.Join(listMethod, util.NL),
|
findOneInterfaceMethod: strings.Join(listMethod, util.NL),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertJoin(key Key) (in, paramJoinString, originalFieldString string) {
|
||||||
|
var inJoin, paramJoin, argJoin Join
|
||||||
|
for _, f := range key.Fields {
|
||||||
|
param := stringx.From(f.Name.ToCamel()).Untitle()
|
||||||
|
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
|
||||||
|
paramJoin = append(paramJoin, param)
|
||||||
|
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
|
||||||
|
}
|
||||||
|
if len(inJoin) > 0 {
|
||||||
|
in = inJoin.With(", ").Source()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(paramJoin) > 0 {
|
||||||
|
paramJoinString = paramJoin.With(",").Source()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(argJoin) > 0 {
|
||||||
|
originalFieldString = argJoin.With(" and ").Source()
|
||||||
|
}
|
||||||
|
return in, paramJoinString, originalFieldString
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,15 +11,15 @@ import (
|
|||||||
|
|
||||||
// Key describes cache key
|
// Key describes cache key
|
||||||
type Key struct {
|
type Key struct {
|
||||||
// VarLeft describes the varible of cache key expression which likes cacheUserIdPrefix
|
// VarLeft describes the variable of cache key expression which likes cacheUserIdPrefix
|
||||||
VarLeft string
|
VarLeft string
|
||||||
// VarRight describes the value of cache key expression which likes "cache#user#id#"
|
// VarRight describes the value of cache key expression which likes "cache#user#id#"
|
||||||
VarRight string
|
VarRight string
|
||||||
// VarExpression describes the cache key expression which likes cacheUserIdPrefix = "cache#user#id#"
|
// VarExpression describes the cache key expression which likes cacheUserIdPrefix = "cache#user#id#"
|
||||||
VarExpression string
|
VarExpression string
|
||||||
// KeyLeft describes the varible of key definiation expression which likes userKey
|
// KeyLeft describes the variable of key definition expression which likes userKey
|
||||||
KeyLeft string
|
KeyLeft string
|
||||||
// KeyRight describes the value of key definiation expression which likes fmt.Sprintf("%s%v", cacheUserPrefix, user)
|
// KeyRight describes the value of key definition expression which likes fmt.Sprintf("%s%v", cacheUserPrefix, user)
|
||||||
KeyRight string
|
KeyRight string
|
||||||
// DataKeyRight describes data key likes fmt.Sprintf("%s%v", cacheUserPrefix, data.User)
|
// DataKeyRight describes data key likes fmt.Sprintf("%s%v", cacheUserPrefix, data.User)
|
||||||
DataKeyRight string
|
DataKeyRight string
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package gen
|
|||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/core/collection"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
|
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
|
||||||
@@ -23,6 +24,15 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
|
|||||||
expressionValues = append(expressionValues, "data."+camel)
|
expressionValues = append(expressionValues, "data."+camel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
keySet := collection.NewSet()
|
||||||
|
keyVariableSet := collection.NewSet()
|
||||||
|
keySet.AddStr(table.PrimaryCacheKey.DataKeyExpression)
|
||||||
|
keyVariableSet.AddStr(table.PrimaryCacheKey.KeyLeft)
|
||||||
|
for _, key := range table.UniqueCacheKey {
|
||||||
|
keySet.AddStr(key.DataKeyExpression)
|
||||||
|
keyVariableSet.AddStr(key.KeyLeft)
|
||||||
|
}
|
||||||
|
|
||||||
expressionValues = append(expressionValues, "data."+table.PrimaryKey.Name.ToCamel())
|
expressionValues = append(expressionValues, "data."+table.PrimaryKey.Name.ToCamel())
|
||||||
camelTableName := table.Name.ToCamel()
|
camelTableName := table.Name.ToCamel()
|
||||||
text, err := util.LoadTemplate(category, updateTemplateFile, template.Update)
|
text, err := util.LoadTemplate(category, updateTemplateFile, template.Update)
|
||||||
@@ -35,6 +45,8 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
|
|||||||
Execute(map[string]interface{}{
|
Execute(map[string]interface{}{
|
||||||
"withCache": withCache,
|
"withCache": withCache,
|
||||||
"upperStartCamelObject": camelTableName,
|
"upperStartCamelObject": camelTableName,
|
||||||
|
"keys": strings.Join(keySet.KeysStr(), "\n"),
|
||||||
|
"keyValues": strings.Join(keyVariableSet.KeysStr(), ", "),
|
||||||
"primaryCacheKey": table.PrimaryCacheKey.DataKeyExpression,
|
"primaryCacheKey": table.PrimaryCacheKey.DataKeyExpression,
|
||||||
"primaryKeyVariable": table.PrimaryCacheKey.KeyLeft,
|
"primaryKeyVariable": table.PrimaryCacheKey.KeyLeft,
|
||||||
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
|
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
|
||||||
|
|||||||
@@ -102,6 +102,17 @@ func Parse(ddl string) (*Table, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
|
||||||
|
return &Table{
|
||||||
|
Name: stringx.From(tableName),
|
||||||
|
PrimaryKey: primaryKey,
|
||||||
|
UniqueIndex: uniqueIndex,
|
||||||
|
NormalIndex: normalIndex,
|
||||||
|
Fields: fields,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
|
||||||
log := console.NewColorConsole()
|
log := console.NewColorConsole()
|
||||||
uniqueSet := collection.NewSet()
|
uniqueSet := collection.NewSet()
|
||||||
for k, i := range uniqueIndex {
|
for k, i := range uniqueIndex {
|
||||||
@@ -136,14 +147,6 @@ func Parse(ddl string) (*Table, error) {
|
|||||||
|
|
||||||
normalIndexSet.Add(joinRet)
|
normalIndexSet.Add(joinRet)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Table{
|
|
||||||
Name: stringx.From(tableName),
|
|
||||||
PrimaryKey: primaryKey,
|
|
||||||
UniqueIndex: uniqueIndex,
|
|
||||||
NormalIndex: normalIndex,
|
|
||||||
Fields: fields,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
|
func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
|
||||||
@@ -289,27 +292,9 @@ func ConvertDataType(table *model.Table) (*Table, error) {
|
|||||||
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
|
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
|
||||||
}
|
}
|
||||||
|
|
||||||
fieldM := make(map[string]*Field)
|
fieldM, err := getTableFields(table)
|
||||||
for _, each := range table.Columns {
|
if err != nil {
|
||||||
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
|
return nil, err
|
||||||
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
columnSeqInIndex := 0
|
|
||||||
if each.Index != nil {
|
|
||||||
columnSeqInIndex = each.Index.SeqInIndex
|
|
||||||
}
|
|
||||||
|
|
||||||
field := &Field{
|
|
||||||
Name: stringx.From(each.Name),
|
|
||||||
DataBaseType: each.DataType,
|
|
||||||
DataType: dt,
|
|
||||||
Comment: each.Comment,
|
|
||||||
SeqInIndex: columnSeqInIndex,
|
|
||||||
OrdinalPosition: each.OrdinalPosition,
|
|
||||||
}
|
|
||||||
fieldM[each.Name] = field
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, each := range fieldM {
|
for _, each := range fieldM {
|
||||||
@@ -379,3 +364,29 @@ func ConvertDataType(table *model.Table) (*Table, error) {
|
|||||||
|
|
||||||
return &reply, nil
|
return &reply, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTableFields(table *model.Table) (map[string]*Field, error) {
|
||||||
|
fieldM := make(map[string]*Field)
|
||||||
|
for _, each := range table.Columns {
|
||||||
|
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
|
||||||
|
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
columnSeqInIndex := 0
|
||||||
|
if each.Index != nil {
|
||||||
|
columnSeqInIndex = each.Index.SeqInIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
field := &Field{
|
||||||
|
Name: stringx.From(each.Name),
|
||||||
|
DataBaseType: each.DataType,
|
||||||
|
DataType: dt,
|
||||||
|
Comment: each.Comment,
|
||||||
|
SeqInIndex: columnSeqInIndex,
|
||||||
|
OrdinalPosition: each.OrdinalPosition,
|
||||||
|
}
|
||||||
|
fieldM[each.Name] = field
|
||||||
|
}
|
||||||
|
return fieldM, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,11 +3,11 @@ package template
|
|||||||
// Update defines a template for generating update codes
|
// Update defines a template for generating update codes
|
||||||
var Update = `
|
var Update = `
|
||||||
func (m *default{{.upperStartCamelObject}}Model) Update(data {{.upperStartCamelObject}}) error {
|
func (m *default{{.upperStartCamelObject}}Model) Update(data {{.upperStartCamelObject}}) error {
|
||||||
{{if .withCache}}{{.primaryCacheKey}}
|
{{if .withCache}}{{.keys}}
|
||||||
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
|
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||||
query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder)
|
query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder)
|
||||||
return conn.Exec(query, {{.expressionValues}})
|
return conn.Exec(query, {{.expressionValues}})
|
||||||
}, {{.primaryKeyVariable}}){{else}}query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder)
|
}, {{.keyValues}}){{else}}query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder)
|
||||||
_,err:=m.conn.Exec(query, {{.expressionValues}}){{end}}
|
_,err:=m.conn.Exec(query, {{.expressionValues}}){{end}}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/tal-tech/go-zero/tools/goctl/api/gogen"
|
"github.com/tal-tech/go-zero/tools/goctl/api/gogen"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/docker"
|
"github.com/tal-tech/go-zero/tools/goctl/docker"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/kube"
|
"github.com/tal-tech/go-zero/tools/goctl/kube"
|
||||||
|
mongogen "github.com/tal-tech/go-zero/tools/goctl/model/mongo/generate"
|
||||||
modelgen "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
|
modelgen "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
|
||||||
rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
|
rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||||
@@ -16,7 +17,7 @@ import (
|
|||||||
|
|
||||||
const templateParentPath = "/"
|
const templateParentPath = "/"
|
||||||
|
|
||||||
// GenTemplates wtites the latest template text into file which is not exists
|
// GenTemplates writes the latest template text into file which is not exists
|
||||||
func GenTemplates(ctx *cli.Context) error {
|
func GenTemplates(ctx *cli.Context) error {
|
||||||
if err := errorx.Chain(
|
if err := errorx.Chain(
|
||||||
func() error {
|
func() error {
|
||||||
@@ -34,6 +35,9 @@ func GenTemplates(ctx *cli.Context) error {
|
|||||||
func() error {
|
func() error {
|
||||||
return kube.GenTemplates(ctx)
|
return kube.GenTemplates(ctx)
|
||||||
},
|
},
|
||||||
|
func() error {
|
||||||
|
return mongogen.Templates(ctx)
|
||||||
|
},
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -61,6 +65,15 @@ func CleanTemplates(_ *cli.Context) error {
|
|||||||
func() error {
|
func() error {
|
||||||
return rpcgen.Clean()
|
return rpcgen.Clean()
|
||||||
},
|
},
|
||||||
|
func() error {
|
||||||
|
return docker.Clean()
|
||||||
|
},
|
||||||
|
func() error {
|
||||||
|
return kube.Clean()
|
||||||
|
},
|
||||||
|
func() error {
|
||||||
|
return mongogen.Clean()
|
||||||
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -70,8 +83,8 @@ func CleanTemplates(_ *cli.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateTemplates wtites the latest template text into file,
|
// UpdateTemplates writes the latest template text into file,
|
||||||
// it will delete the oldler templates if there are exists
|
// it will delete the older templates if there are exists
|
||||||
func UpdateTemplates(ctx *cli.Context) (err error) {
|
func UpdateTemplates(ctx *cli.Context) (err error) {
|
||||||
category := ctx.String("category")
|
category := ctx.String("category")
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -90,6 +103,8 @@ func UpdateTemplates(ctx *cli.Context) (err error) {
|
|||||||
return rpcgen.Update()
|
return rpcgen.Update()
|
||||||
case modelgen.Category():
|
case modelgen.Category():
|
||||||
return modelgen.Update()
|
return modelgen.Update()
|
||||||
|
case mongogen.Category():
|
||||||
|
return mongogen.Update()
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("unexpected category: %s", category)
|
err = fmt.Errorf("unexpected category: %s", category)
|
||||||
return
|
return
|
||||||
@@ -116,6 +131,8 @@ func RevertTemplates(ctx *cli.Context) (err error) {
|
|||||||
return rpcgen.RevertTemplate(filename)
|
return rpcgen.RevertTemplate(filename)
|
||||||
case modelgen.Category():
|
case modelgen.Category():
|
||||||
return modelgen.RevertTemplate(filename)
|
return modelgen.RevertTemplate(filename)
|
||||||
|
case mongogen.Category():
|
||||||
|
return mongogen.RevertTemplate(filename)
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("unexpected category: %s", category)
|
err = fmt.Errorf("unexpected category: %s", category)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
type (
|
type (
|
||||||
// Console wraps from the fmt.Sprintf,
|
// Console wraps from the fmt.Sprintf,
|
||||||
// by default, it implemented the colorConsole to provide the colorful output to the consle
|
// by default, it implemented the colorConsole to provide the colorful output to the console
|
||||||
// and the ideaConsole to output with prefix for the plugin of intellij
|
// and the ideaConsole to output with prefix for the plugin of intellij
|
||||||
Console interface {
|
Console interface {
|
||||||
Success(format string, a ...interface{})
|
Success(format string, a ...interface{})
|
||||||
@@ -81,7 +81,7 @@ func (c *colorConsole) Must(err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewIdeaConsole returns a instace of ideaConsole
|
// NewIdeaConsole returns a instance of ideaConsole
|
||||||
func NewIdeaConsole() Console {
|
func NewIdeaConsole() Console {
|
||||||
return &ideaConsole{}
|
return &ideaConsole{}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func CreateIfNotExist(file string) (*os.File, error) {
|
|||||||
return os.Create(file)
|
return os.Create(file)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveIfExist deletes the specficed file if it is exists
|
// RemoveIfExist deletes the specified file if it is exists
|
||||||
func RemoveIfExist(filename string) error {
|
func RemoveIfExist(filename string) error {
|
||||||
if !FileExists(filename) {
|
if !FileExists(filename) {
|
||||||
return nil
|
return nil
|
||||||
@@ -36,7 +36,7 @@ func RemoveIfExist(filename string) error {
|
|||||||
return os.Remove(filename)
|
return os.Remove(filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOrQuit deletes the specficed file if read a permit command from stdin
|
// RemoveOrQuit deletes the specified file if read a permit command from stdin
|
||||||
func RemoveOrQuit(filename string) error {
|
func RemoveOrQuit(filename string) error {
|
||||||
if !FileExists(filename) {
|
if !FileExists(filename) {
|
||||||
return nil
|
return nil
|
||||||
@@ -49,7 +49,7 @@ func RemoveOrQuit(filename string) error {
|
|||||||
return os.Remove(filename)
|
return os.Remove(filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FileExists returns true if the specficed file is exists
|
// FileExists returns true if the specified file is exists
|
||||||
func FileExists(file string) bool {
|
func FileExists(file string) bool {
|
||||||
_, err := os.Stat(file)
|
_, err := os.Stat(file)
|
||||||
return err == nil
|
return err == nil
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ const (
|
|||||||
upper
|
upper
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrNamingFormat defines an error for unknown fomat
|
// ErrNamingFormat defines an error for unknown format
|
||||||
var ErrNamingFormat = errors.New("unsupported format")
|
var ErrNamingFormat = errors.New("unsupported format")
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ func MkdirIfNotExist(dir string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PathFromGoSrc returns the path whihout slash where has been trim the prefix $GOPATH
|
// PathFromGoSrc returns the path without slash where has been trim the prefix $GOPATH
|
||||||
func PathFromGoSrc() (string, error) {
|
func PathFromGoSrc() (string, error) {
|
||||||
dir, err := os.Getwd()
|
dir, err := os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"unicode"
|
"unicode"
|
||||||
)
|
)
|
||||||
|
|
||||||
// String provides for coverting the source text into other spell case,like lower,snake,camel
|
// String provides for converting the source text into other spell case,like lower,snake,camel
|
||||||
type String struct {
|
type String struct {
|
||||||
source string
|
source string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ type DefaultTemplate struct {
|
|||||||
savePath string
|
savePath string
|
||||||
}
|
}
|
||||||
|
|
||||||
// With returns a instace of DefaultTemplate
|
// With returns a instance of DefaultTemplate
|
||||||
func With(name string) *DefaultTemplate {
|
func With(name string) *DefaultTemplate {
|
||||||
return &DefaultTemplate{
|
return &DefaultTemplate{
|
||||||
name: name,
|
name: name,
|
||||||
@@ -30,7 +30,7 @@ func (t *DefaultTemplate) Parse(text string) *DefaultTemplate {
|
|||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
// GoFmt sets the value to goFmt and marks the generated codes will be formated or not
|
// GoFmt sets the value to goFmt and marks the generated codes will be formatted or not
|
||||||
func (t *DefaultTemplate) GoFmt(format bool) *DefaultTemplate {
|
func (t *DefaultTemplate) GoFmt(format bool) *DefaultTemplate {
|
||||||
t.goFmt = format
|
t.goFmt = format
|
||||||
return t
|
return t
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package vars
|
|||||||
const (
|
const (
|
||||||
// ProjectName the const value of zero
|
// ProjectName the const value of zero
|
||||||
ProjectName = "zero"
|
ProjectName = "zero"
|
||||||
// ProjectOpenSourceURL the githb url of go-zero
|
// ProjectOpenSourceURL the github url of go-zero
|
||||||
ProjectOpenSourceURL = "github.com/tal-tech/go-zero"
|
ProjectOpenSourceURL = "github.com/tal-tech/go-zero"
|
||||||
// OsWindows windows os
|
// OsWindows windows os
|
||||||
OsWindows = "windows"
|
OsWindows = "windows"
|
||||||
|
|||||||
@@ -18,6 +18,27 @@ func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
|
|||||||
|
|
||||||
ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
|
ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
return invoker(ctx, method, req, reply, cc, opts...)
|
|
||||||
|
// create channel with buffer size 1 to avoid goroutine leak
|
||||||
|
done := make(chan error, 1)
|
||||||
|
panicChan := make(chan interface{}, 1)
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if p := recover(); p != nil {
|
||||||
|
panicChan <- p
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
done <- invoker(ctx, method, req, reply, cc, opts...)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case p := <-panicChan:
|
||||||
|
panic(p)
|
||||||
|
case err := <-done:
|
||||||
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,3 +48,40 @@ func TestTimeoutInterceptor_timeout(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTimeoutInterceptor_timeoutExpire(t *testing.T) {
|
||||||
|
const timeout = time.Millisecond * 10
|
||||||
|
interceptor := TimeoutInterceptor(timeout)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
cc := new(grpc.ClientConn)
|
||||||
|
err := interceptor(ctx, "/foo", nil, nil, cc,
|
||||||
|
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||||
|
opts ...grpc.CallOption) error {
|
||||||
|
defer wg.Done()
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, context.DeadlineExceeded, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTimeoutInterceptor_panic(t *testing.T) {
|
||||||
|
timeouts := []time.Duration{0, time.Millisecond * 10}
|
||||||
|
for _, timeout := range timeouts {
|
||||||
|
t.Run(strconv.FormatInt(int64(timeout), 10), func(t *testing.T) {
|
||||||
|
interceptor := TimeoutInterceptor(timeout)
|
||||||
|
cc := new(grpc.ClientConn)
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
_ = interceptor(context.Background(), "/foo", nil, nil, cc,
|
||||||
|
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||||
|
opts ...grpc.CallOption) error {
|
||||||
|
panic("any")
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package serverinterceptors
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tal-tech/go-zero/core/contextx"
|
"github.com/tal-tech/go-zero/core/contextx"
|
||||||
@@ -11,9 +12,38 @@ import (
|
|||||||
// UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
|
// UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
|
||||||
func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
|
func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
|
||||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||||
handler grpc.UnaryHandler) (resp interface{}, err error) {
|
handler grpc.UnaryHandler) (interface{}, error) {
|
||||||
ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
|
ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
return handler(ctx, req)
|
|
||||||
|
var resp interface{}
|
||||||
|
var err error
|
||||||
|
var lock sync.Mutex
|
||||||
|
done := make(chan struct{})
|
||||||
|
// create channel with buffer size 1 to avoid goroutine leak
|
||||||
|
panicChan := make(chan interface{}, 1)
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if p := recover(); p != nil {
|
||||||
|
panicChan <- p
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
resp, err = handler(ctx, req)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case p := <-panicChan:
|
||||||
|
panic(p)
|
||||||
|
case <-done:
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
return resp, err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,17 @@ func TestUnaryTimeoutInterceptor(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnaryTimeoutInterceptor_panic(t *testing.T) {
|
||||||
|
interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
_, _ = interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||||
|
FullMethod: "/",
|
||||||
|
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
panic("any")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
|
func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
|
||||||
const timeout = time.Millisecond * 10
|
const timeout = time.Millisecond * 10
|
||||||
interceptor := UnaryTimeoutInterceptor(timeout)
|
interceptor := UnaryTimeoutInterceptor(timeout)
|
||||||
@@ -39,3 +50,21 @@ func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) {
|
||||||
|
const timeout = time.Millisecond * 10
|
||||||
|
interceptor := UnaryTimeoutInterceptor(timeout)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
||||||
|
FullMethod: "/",
|
||||||
|
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
defer wg.Done()
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, context.DeadlineExceeded, err)
|
||||||
|
}
|
||||||
|
|||||||
@@ -79,6 +79,8 @@ func (rs *RpcServer) AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterc
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the RpcServer.
|
// Start starts the RpcServer.
|
||||||
|
// Graceful shutdown is enabled by default.
|
||||||
|
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
|
||||||
func (rs *RpcServer) Start() {
|
func (rs *RpcServer) Start() {
|
||||||
if err := rs.server.Start(rs.register); err != nil {
|
if err := rs.server.Start(rs.register); err != nil {
|
||||||
logx.Error(err)
|
logx.Error(err)
|
||||||
|
|||||||
Reference in New Issue
Block a user