Compare commits

...

19 Commits

Author SHA1 Message Date
anqiansong
888551627c optimize code (#579)
* optimize code

* optimize returns & unit test
2021-03-27 17:33:17 +08:00
Kevin Wan
bd623aaac3 support postgresql (#583)
support postgresql
2021-03-27 17:14:32 +08:00
Kevin Wan
9e6c2ba2c0 avoid goroutine leak after timeout (#575) 2021-03-21 16:54:34 +08:00
Kevin Wan
c0db8d017d gofmt logs (#574) 2021-03-20 16:40:09 +08:00
TonyWang
52b4f8ca91 add timezone and timeformat (#572)
* add timezone and timeformat

* rm time zone and keep time format

Co-authored-by: Tony Wang <tonywang.data@gmail.com>
2021-03-20 16:36:19 +08:00
Kevin Wan
4884a7b3c6 zrpc timeout & unit tests (#573)
* zrpc timeout & unit tests
2021-03-19 18:41:26 +08:00
Kevin Wan
3c6951577d make hijack more stable (#565) 2021-03-15 20:11:09 +08:00
Kevin Wan
fcd15c9b17 refactor, and add comments to describe graceful shutdown (#564) 2021-03-14 08:51:10 +08:00
Kevin Wan
155e6061cb fix golint issues (#561) 2021-03-12 23:08:04 +08:00
anqiansong
dda7666097 Feature mongo gen (#546)
* add feature: mongo code generation

* upgrade version

* update doc

* format code

* update update.tpl of mysql
2021-03-12 17:49:28 +08:00
hanhotfox
c954568b61 Hdel support for multiple key deletion (#542)
* Hdel support for multiple key deletion

* Hdel field -> fields

Co-authored-by: duanyan <duanyan@xiaoheiban.cn>
2021-03-12 17:47:21 +08:00
Kevin Wan
c2acc43a52 add important notes in readme (#560) 2021-03-12 16:48:25 +08:00
Kevin Wan
1a1a6f5239 add http hijack methods (#555) 2021-03-09 21:30:45 +08:00
anqiansong
60c7edf8f8 fix spelling (#551) 2021-03-08 18:23:12 +08:00
Kevin Wan
7ad86a52f3 update doc link (#552) 2021-03-08 17:56:03 +08:00
kingxt
1e4e5a02b2 rename (#543) 2021-03-04 17:13:07 +08:00
Kevin Wan
39540e21d2 fix golint issues (#540) 2021-03-03 17:16:09 +08:00
hexiaoen
b321622c95 暴露redis EvalSha 以及ScriptLoad接口 (#538)
Co-authored-by: shanehe <shanehe@zego.im>
2021-03-03 17:09:27 +08:00
kingxt
a25cba5380 fix collection breaker (#537)
* fix collection breaker

* optimized

* optimized

* optimized
2021-03-03 10:44:29 +08:00
60 changed files with 1146 additions and 224 deletions

View File

@@ -26,7 +26,8 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
ctx, cancel := contextx.ShrinkDeadline(parentCtx, timeout) ctx, cancel := contextx.ShrinkDeadline(parentCtx, timeout)
defer cancel() defer cancel()
done := make(chan error) // create channel with buffer size 1 to avoid goroutine leak
done := make(chan error, 1)
panicChan := make(chan interface{}, 1) panicChan := make(chan interface{}, 1)
go func() { go func() {
defer func() { defer func() {
@@ -35,7 +36,6 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
} }
}() }()
done <- fn() done <- fn()
close(done)
}() }()
select { select {

View File

@@ -4,6 +4,7 @@ package logx
type LogConf struct { type LogConf struct {
ServiceName string `json:",optional"` ServiceName string `json:",optional"`
Mode string `json:",default=console,options=console|file|volume"` Mode string `json:",default=console,options=console|file|volume"`
TimeFormat string `json:",optional"`
Path string `json:",default=logs"` Path string `json:",default=logs"`
Level string `json:",default=info,options=info|error|severe"` Level string `json:",default=info,options=info|error|severe"`
Compress bool `json:",optional"` Compress bool `json:",optional"`

View File

@@ -32,8 +32,6 @@ const (
) )
const ( const (
timeFormat = "2006-01-02T15:04:05.000Z07"
accessFilename = "access.log" accessFilename = "access.log"
errorFilename = "error.log" errorFilename = "error.log"
severeFilename = "severe.log" severeFilename = "severe.log"
@@ -64,6 +62,7 @@ var (
// ErrLogServiceNameNotSet is an error that indicates that the service name is not set. // ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
ErrLogServiceNameNotSet = errors.New("log service name must be set") ErrLogServiceNameNotSet = errors.New("log service name must be set")
timeFormat = "2006-01-02T15:04:05.000Z07"
writeConsole bool writeConsole bool
logLevel uint32 logLevel uint32
infoLog io.WriteCloser infoLog io.WriteCloser
@@ -117,6 +116,10 @@ func MustSetup(c LogConf) {
// we need to allow different service frameworks to initialize logx respectively. // we need to allow different service frameworks to initialize logx respectively.
// the same logic for SetUp // the same logic for SetUp
func SetUp(c LogConf) error { func SetUp(c LogConf) error {
if len(c.TimeFormat) > 0 {
timeFormat = c.TimeFormat
}
switch c.Mode { switch c.Mode {
case consoleMode: case consoleMode:
setupWithConsole(c) setupWithConsole(c)

View File

@@ -43,11 +43,11 @@ type (
} }
) )
func newCollection(collection *mgo.Collection) Collection { func newCollection(collection *mgo.Collection, brk breaker.Breaker) Collection {
return &decoratedCollection{ return &decoratedCollection{
name: collection.FullName, name: collection.FullName,
collection: collection, collection: collection,
brk: breaker.NewBreaker(), brk: brk,
} }
} }

View File

@@ -71,7 +71,7 @@ func TestNewCollection(t *testing.T) {
Database: nil, Database: nil,
Name: "foo", Name: "foo",
FullName: "bar", FullName: "bar",
}) }, breaker.GetBreaker("localhost"))
assert.Equal(t, "bar", col.(*decoratedCollection).name) assert.Equal(t, "bar", col.(*decoratedCollection).name)
} }

View File

@@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/globalsign/mgo" "github.com/globalsign/mgo"
"github.com/tal-tech/go-zero/core/breaker"
) )
type ( type (
@@ -20,6 +21,7 @@ type (
session *concurrentSession session *concurrentSession
db *mgo.Database db *mgo.Database
collection string collection string
brk breaker.Breaker
opts []Option opts []Option
} }
) )
@@ -46,6 +48,7 @@ func NewModel(url, collection string, opts ...Option) (*Model, error) {
// If name is empty, the database name provided in the dialed URL is used instead // If name is empty, the database name provided in the dialed URL is used instead
db: session.DB(""), db: session.DB(""),
collection: collection, collection: collection,
brk: breaker.GetBreaker(url),
opts: opts, opts: opts,
}, nil }, nil
} }
@@ -66,7 +69,7 @@ func (mm *Model) FindId(id interface{}) (Query, error) {
// GetCollection returns a Collection with given session. // GetCollection returns a Collection with given session.
func (mm *Model) GetCollection(session *mgo.Session) Collection { func (mm *Model) GetCollection(session *mgo.Session) Collection {
return newCollection(mm.db.C(mm.collection).With(session)) return newCollection(mm.db.C(mm.collection).With(session), mm.brk)
} }
// Insert inserts docs into mm. // Insert inserts docs into mm.

View File

@@ -250,6 +250,21 @@ func (s *Redis) Eval(script string, keys []string, args ...interface{}) (val int
return return
} }
// EvalSha is the implementation of redis evalsha command.
func (s *Redis) EvalSha(sha string, keys []string, args ...interface{}) (val interface{}, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
val, err = conn.EvalSha(sha, keys, args...).Result()
return err
}, acceptable)
return
}
// Exists is the implementation of redis exists command. // Exists is the implementation of redis exists command.
func (s *Redis) Exists(key string) (val bool, err error) { func (s *Redis) Exists(key string) (val bool, err error) {
err = s.brk.DoWithAcceptable(func() error { err = s.brk.DoWithAcceptable(func() error {
@@ -449,14 +464,14 @@ func (s *Redis) GetBit(key string, offset int64) (val int, err error) {
} }
// Hdel is the implementation of redis hdel command. // Hdel is the implementation of redis hdel command.
func (s *Redis) Hdel(key, field string) (val bool, err error) { func (s *Redis) Hdel(key string, fields ...string) (val bool, err error) {
err = s.brk.DoWithAcceptable(func() error { err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s) conn, err := getRedis(s)
if err != nil { if err != nil {
return err return err
} }
v, err := conn.HDel(key, field).Result() v, err := conn.HDel(key, fields...).Result()
if err != nil { if err != nil {
return err return err
} }
@@ -1032,6 +1047,16 @@ func (s *Redis) Scard(key string) (val int64, err error) {
return return
} }
// ScriptLoad is the implementation of redis script load command.
func (s *Redis) ScriptLoad(script string) (string, error) {
conn, err := getRedis(s)
if err != nil {
return "", err
}
return conn.ScriptLoad(script).Result()
}
// Set is the implementation of redis set command. // Set is the implementation of redis set command.
func (s *Redis) Set(key string, value string) error { func (s *Redis) Set(key string, value string) error {
return s.brk.DoWithAcceptable(func() error { return s.brk.DoWithAcceptable(func() error {
@@ -1101,26 +1126,6 @@ func (s *Redis) Sismember(key string, value interface{}) (val bool, err error) {
return return
} }
// Srem is the implementation of redis srem command.
func (s *Redis) Srem(key string, values ...interface{}) (val int, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
v, err := conn.SRem(key, values...).Result()
if err != nil {
return err
}
val = int(v)
return nil
}, acceptable)
return
}
// Smembers is the implementation of redis smembers command. // Smembers is the implementation of redis smembers command.
func (s *Redis) Smembers(key string) (val []string, err error) { func (s *Redis) Smembers(key string) (val []string, err error) {
err = s.brk.DoWithAcceptable(func() error { err = s.brk.DoWithAcceptable(func() error {
@@ -1166,6 +1171,31 @@ func (s *Redis) Srandmember(key string, count int) (val []string, err error) {
return return
} }
// Srem is the implementation of redis srem command.
func (s *Redis) Srem(key string, values ...interface{}) (val int, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
v, err := conn.SRem(key, values...).Result()
if err != nil {
return err
}
val = int(v)
return nil
}, acceptable)
return
}
// String returns the string representation of s.
func (s *Redis) String() string {
return s.Addr
}
// Sunion is the implementation of redis sunion command. // Sunion is the implementation of redis sunion command.
func (s *Redis) Sunion(keys ...string) (val []string, err error) { func (s *Redis) Sunion(keys ...string) (val []string, err error) {
err = s.brk.DoWithAcceptable(func() error { err = s.brk.DoWithAcceptable(func() error {
@@ -1667,20 +1697,6 @@ func (s *Redis) Zunionstore(dest string, store ZStore, keys ...string) (val int6
return return
} }
// String returns the string representation of s.
func (s *Redis) String() string {
return s.Addr
}
func (s *Redis) scriptLoad(script string) (string, error) {
conn, err := getRedis(s)
if err != nil {
return "", err
}
return conn.ScriptLoad(script).Result()
}
func acceptable(err error) bool { func acceptable(err error) bool {
return err == nil || err == red.Nil return err == nil || err == red.Nil
} }

View File

@@ -947,13 +947,24 @@ func TestRedisString(t *testing.T) {
func TestRedisScriptLoad(t *testing.T) { func TestRedisScriptLoad(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
client.Ping() client.Ping()
_, err := NewRedis(client.Addr, "").scriptLoad("foo") _, err := NewRedis(client.Addr, "").ScriptLoad("foo")
assert.NotNil(t, err) assert.NotNil(t, err)
_, err = client.scriptLoad("foo") _, err = client.ScriptLoad("foo")
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
} }
func TestRedisEvalSha(t *testing.T) {
runOnRedis(t, func(client *Redis) {
client.Ping()
scriptHash, err := client.ScriptLoad(`return redis.call("EXISTS", KEYS[1])`)
assert.Nil(t, err)
result, err := client.EvalSha(scriptHash, []string{"key1"})
assert.Nil(t, err)
assert.Equal(t, int64(0), result)
})
}
func TestRedisToPairs(t *testing.T) { func TestRedisToPairs(t *testing.T) {
pairs := toPairs([]red.Z{ pairs := toPairs([]red.Z{
{ {

View File

@@ -24,6 +24,7 @@ type (
ResultHandler func(sql.Result, error) ResultHandler func(sql.Result, error)
// A BulkInserter is used to batch insert records. // A BulkInserter is used to batch insert records.
// Postgresql is not supported yet, because of the sql is formated with symbol `$`.
BulkInserter struct { BulkInserter struct {
executor *executors.PeriodicalExecutor executor *executors.PeriodicalExecutor
inserter *dbInserter inserter *dbInserter

View File

@@ -12,14 +12,10 @@ import (
const slowThreshold = time.Millisecond * 500 const slowThreshold = time.Millisecond * 500
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) { func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now() startTime := timex.Now()
result, err := conn.Exec(q, args...) result, err := conn.Exec(q, args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
stmt := formatForPrint(q, args)
if duration > slowThreshold { if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt) logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
} else { } else {
@@ -33,10 +29,10 @@ func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
} }
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) { func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
stmt := fmt.Sprint(args...)
startTime := timex.Now() startTime := timex.Now()
result, err := conn.Exec(args...) result, err := conn.Exec(args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
stmt := fmt.Sprint(args...)
if duration > slowThreshold { if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt) logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
} else { } else {
@@ -50,14 +46,10 @@ func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
} }
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error { func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now() startTime := timex.Now()
rows, err := conn.Query(q, args...) rows, err := conn.Query(q, args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
stmt := fmt.Sprint(args...)
if duration > slowThreshold { if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt) logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
} else { } else {

View File

@@ -16,7 +16,6 @@ func TestStmt_exec(t *testing.T) {
name string name string
args []interface{} args []interface{}
delay bool delay bool
formatError bool
hasError bool hasError bool
err error err error
lastInsertId int64 lastInsertId int64
@@ -28,12 +27,6 @@ func TestStmt_exec(t *testing.T) {
lastInsertId: 1, lastInsertId: 1,
rowsAffected: 2, rowsAffected: 2,
}, },
{
name: "wrong format",
args: []interface{}{1, 2},
formatError: true,
hasError: true,
},
{ {
name: "exec error", name: "exec error",
args: []interface{}{1}, args: []interface{}{1},
@@ -70,18 +63,13 @@ func TestStmt_exec(t *testing.T) {
}, },
} }
for i, fn := range fns { for _, fn := range fns {
i := i
fn := fn fn := fn
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
t.Parallel() t.Parallel()
res, err := fn(test.args...) res, err := fn(test.args...)
if i == 0 && test.formatError { if test.hasError {
assert.NotNil(t, err)
return
}
if !test.formatError && test.hasError {
assert.NotNil(t, err) assert.NotNil(t, err)
return return
} }
@@ -103,7 +91,6 @@ func TestStmt_query(t *testing.T) {
name string name string
args []interface{} args []interface{}
delay bool delay bool
formatError bool
hasError bool hasError bool
err error err error
}{ }{
@@ -111,12 +98,6 @@ func TestStmt_query(t *testing.T) {
name: "normal", name: "normal",
args: []interface{}{1}, args: []interface{}{1},
}, },
{
name: "wrong format",
args: []interface{}{1, 2},
formatError: true,
hasError: true,
},
{ {
name: "query error", name: "query error",
args: []interface{}{1}, args: []interface{}{1},
@@ -151,18 +132,13 @@ func TestStmt_query(t *testing.T) {
}, },
} }
for i, fn := range fns { for _, fn := range fns {
i := i
fn := fn fn := fn
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
t.Parallel() t.Parallel()
err := fn(test.args...) err := fn(test.args...)
if i == 0 && test.formatError { if test.hasError {
assert.NotNil(t, err)
return
}
if !test.formatError && test.hasError {
assert.NotNil(t, err) assert.NotNil(t, err)
return return
} }

View File

@@ -45,6 +45,24 @@ func escape(input string) string {
return b.String() return b.String()
} }
func formatForPrint(query string, args ...interface{}) string {
if len(args) == 0 {
return query
}
var vals []string
for _, arg := range args {
vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
}
var b strings.Builder
b.WriteByte('[')
b.WriteString(strings.Join(vals, ", "))
b.WriteByte(']')
return strings.Join([]string{query, b.String()}, " ")
}
func format(query string, args ...interface{}) (string, error) { func format(query string, args ...interface{}) (string, error) {
numArgs := len(args) numArgs := len(args)
if numArgs == 0 { if numArgs == 0 {

View File

@@ -28,3 +28,31 @@ func TestDesensitize_WithoutAccount(t *testing.T) {
datasource = desensitize(datasource) datasource = desensitize(datasource)
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)")) assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
} }
func TestFormatForPrint(t *testing.T) {
tests := []struct {
name string
query string
args []interface{}
expect string
}{
{
name: "no args",
query: "select user, name from table where id=?",
expect: `select user, name from table where id=?`,
},
{
name: "one arg",
query: "select user, name from table where id=?",
args: []interface{}{"kevin"},
expect: `select user, name from table where id=? ["kevin"]`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual := formatForPrint(test.query, test.args...)
assert.Equal(t, test.expect, actual)
})
}
}

View File

@@ -159,7 +159,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
* API 文档 * API 文档
[https://www.yuque.com/tal-tech/go-zero](https://www.yuque.com/tal-tech/go-zero) [https://zeromicro.github.io/go-zero](https://zeromicro.github.io/go-zero)
* awesome 系列(更多文章见『微服务实践』公众号) * awesome 系列(更多文章见『微服务实践』公众号)
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md) * [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)

View File

@@ -210,6 +210,12 @@ go get -u github.com/tal-tech/go-zero
* [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore-en.md) * [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore-en.md)
* [Examples](https://github.com/zeromicro/zero-examples) * [Examples](https://github.com/zeromicro/zero-examples)
## 9. Chat group ## 9. Important notes
* Use grpc 1.29.1, because etcd lib doesnt support latter versions.
`google.golang.org/grpc v1.29.1`
## 10. Chat group
Join the chat via https://join.slack.com/t/go-zeroworkspace/shared_invite/zt-m39xssxc-kgIqERa7aVsujKNj~XuPKg Join the chat via https://join.slack.com/t/go-zeroworkspace/shared_invite/zt-m39xssxc-kgIqERa7aVsujKNj~XuPKg

View File

@@ -1,8 +1,10 @@
package handler package handler
import ( import (
"bufio"
"context" "context"
"errors" "errors"
"net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
@@ -138,6 +140,16 @@ func (grw *guardedResponseWriter) Header() http.Header {
return grw.writer.Header() return grw.writer.Header()
} }
// Hijack implements the http.Hijacker interface.
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
func (grw *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacked, ok := grw.writer.(http.Hijacker); ok {
return hijacked.Hijack()
}
return nil, nil, errors.New("server doesn't support hijacking")
}
func (grw *guardedResponseWriter) Write(body []byte) (int, error) { func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
return grw.writer.Write(body) return grw.writer.Write(body)
} }

View File

@@ -1,6 +1,8 @@
package handler package handler
import ( import (
"bufio"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@@ -87,6 +89,26 @@ func TestAuthHandler_NilError(t *testing.T) {
}) })
} }
func TestAuthHandler_Flush(t *testing.T) {
resp := httptest.NewRecorder()
handler := newGuardedResponseWriter(resp)
handler.Flush()
assert.True(t, resp.Flushed)
}
func TestAuthHandler_Hijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := newGuardedResponseWriter(resp)
assert.NotPanics(t, func() {
writer.Hijack()
})
writer = newGuardedResponseWriter(mockedHijackable{resp})
assert.NotPanics(t, func() {
writer.Hijack()
})
}
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) { func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
now := time.Now().Unix() now := time.Now().Unix()
claims := make(jwt.MapClaims) claims := make(jwt.MapClaims)
@@ -101,3 +123,11 @@ func buildToken(secretKey string, payloads map[string]interface{}, seconds int64
return token.SignedString([]byte(secretKey)) return token.SignedString([]byte(secretKey))
} }
type mockedHijackable struct {
*httptest.ResponseRecorder
}
func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, nil
}

View File

@@ -1,11 +1,13 @@
package handler package handler
import ( import (
"bufio"
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"errors" "errors"
"io" "io"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"github.com/tal-tech/go-zero/core/codec" "github.com/tal-tech/go-zero/core/codec"
@@ -94,6 +96,16 @@ func (w *cryptionResponseWriter) Header() http.Header {
return w.ResponseWriter.Header() return w.ResponseWriter.Header()
} }
// Hijack implements the http.Hijacker interface.
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
func (w *cryptionResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacked, ok := w.ResponseWriter.(http.Hijacker); ok {
return hijacked.Hijack()
}
return nil, nil, errors.New("server doesn't support hijacking")
}
func (w *cryptionResponseWriter) Write(p []byte) (int, error) { func (w *cryptionResponseWriter) Write(p []byte) (int, error) {
return w.buf.Write(p) return w.buf.Write(p)
} }

View File

@@ -103,3 +103,16 @@ func TestCryptionHandlerFlush(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String()) assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
} }
func TestCryptionHandler_Hijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := newCryptionResponseWriter(resp)
assert.NotPanics(t, func() {
writer.Hijack()
})
writer = newCryptionResponseWriter(mockedHijackable{resp})
assert.NotPanics(t, func() {
writer.Hijack()
})
}

View File

@@ -1,10 +1,13 @@
package handler package handler
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"time" "time"
@@ -25,10 +28,26 @@ type loggedResponseWriter struct {
code int code int
} }
func (w *loggedResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *loggedResponseWriter) Header() http.Header { func (w *loggedResponseWriter) Header() http.Header {
return w.w.Header() return w.w.Header()
} }
// Hijack implements the http.Hijacker interface.
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
func (w *loggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacked, ok := w.w.(http.Hijacker); ok {
return hijacked.Hijack()
}
return nil, nil, errors.New("server doesn't support hijacking")
}
func (w *loggedResponseWriter) Write(bytes []byte) (int, error) { func (w *loggedResponseWriter) Write(bytes []byte) (int, error) {
return w.w.Write(bytes) return w.w.Write(bytes)
} }
@@ -38,12 +57,6 @@ func (w *loggedResponseWriter) WriteHeader(code int) {
w.code = code w.code = code
} }
func (w *loggedResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush()
}
}
// LogHandler returns a middleware that logs http request and response. // LogHandler returns a middleware that logs http request and response.
func LogHandler(next http.Handler) http.Handler { func LogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -83,6 +96,16 @@ func (w *detailLoggedResponseWriter) Header() http.Header {
return w.writer.Header() return w.writer.Header()
} }
// Hijack implements the http.Hijacker interface.
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
func (w *detailLoggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacked, ok := w.writer.w.(http.Hijacker); ok {
return hijacked.Hijack()
}
return nil, nil, errors.New("server doesn't support hijacking")
}
func (w *detailLoggedResponseWriter) Write(bs []byte) (int, error) { func (w *detailLoggedResponseWriter) Write(bs []byte) (int, error) {
w.buf.Write(bs) w.buf.Write(bs)
return w.writer.Write(bs) return w.writer.Write(bs)

View File

@@ -62,6 +62,44 @@ func TestLogHandlerSlow(t *testing.T) {
} }
} }
func TestLogHandler_Hijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := &loggedResponseWriter{
w: resp,
}
assert.NotPanics(t, func() {
writer.Hijack()
})
writer = &loggedResponseWriter{
w: mockedHijackable{resp},
}
assert.NotPanics(t, func() {
writer.Hijack()
})
}
func TestDetailedLogHandler_Hijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := &detailLoggedResponseWriter{
writer: &loggedResponseWriter{
w: resp,
},
}
assert.NotPanics(t, func() {
writer.Hijack()
})
writer = &detailLoggedResponseWriter{
writer: &loggedResponseWriter{
w: mockedHijackable{resp},
},
}
assert.NotPanics(t, func() {
writer.Hijack()
})
}
func BenchmarkLogHandler(b *testing.B) { func BenchmarkLogHandler(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()

View File

@@ -1,6 +1,10 @@
package security package security
import "net/http" import (
"bufio"
"net"
"net/http"
)
// A WithCodeResponseWriter is a helper to delay sealing a http.ResponseWriter on writing code. // A WithCodeResponseWriter is a helper to delay sealing a http.ResponseWriter on writing code.
type WithCodeResponseWriter struct { type WithCodeResponseWriter struct {
@@ -20,6 +24,12 @@ func (w *WithCodeResponseWriter) Header() http.Header {
return w.Writer.Header() return w.Writer.Header()
} }
// Hijack implements the http.Hijacker interface.
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.Writer.(http.Hijacker).Hijack()
}
// Write writes bytes into w. // Write writes bytes into w.
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) { func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
return w.Writer.Write(bytes) return w.Writer.Write(bytes)

View File

@@ -77,6 +77,8 @@ func (e *Server) AddRoute(r Route, opts ...RouteOption) {
} }
// Start starts the Server. // Start starts the Server.
// Graceful shutdown is enabled by default.
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
func (e *Server) Start() { func (e *Server) Start() {
handleError(e.opts.start(e.ngin)) handleError(e.opts.start(e.ngin))
} }

View File

@@ -103,17 +103,9 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
} }
func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error { func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error {
defineStruct, ok := ty.(spec.DefineStruct) defineStruct, done, err := c.checkStruct(ty)
if !ok { if done {
return errors.New("unsupported type %s" + ty.Name()) return err
}
for _, item := range c.requestTypes {
if item.Name() == defineStruct.Name() {
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
return nil
}
}
} }
modelFile := util.Title(ty.Name()) + ".java" modelFile := util.Title(ty.Name()) + ".java"
@@ -181,6 +173,22 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
return err return err
} }
func (c *componentsContext) checkStruct(ty spec.Type) (spec.DefineStruct, bool, error) {
defineStruct, ok := ty.(spec.DefineStruct)
if !ok {
return spec.DefineStruct{}, true, errors.New("unsupported type %s" + ty.Name())
}
for _, item := range c.requestTypes {
if item.Name() == defineStruct.Name() {
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
return spec.DefineStruct{}, true, nil
}
}
}
return defineStruct, false, nil
}
func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) { func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) {
var builder strings.Builder var builder strings.Builder
if err := c.writeType(&builder, defineStruct); err != nil { if err := c.writeType(&builder, defineStruct); err != nil {

View File

@@ -95,17 +95,9 @@ func specTypeToJava(tp spec.Type) (string, error) {
return "", err return "", err
} }
switch valueType { s := getBaseType(valueType)
case "int": if len(s) == 0 {
return "Integer[]", nil return s, errors.New("unsupported primitive type " + tp.Name())
case "long":
return "Long[]", nil
case "float":
return "Float[]", nil
case "double":
return "Double[]", nil
case "boolean":
return "Boolean[]", nil
} }
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil
@@ -118,6 +110,23 @@ func specTypeToJava(tp spec.Type) (string, error) {
return "", errors.New("unsupported primitive type " + tp.Name()) return "", errors.New("unsupported primitive type " + tp.Name())
} }
func getBaseType(valueType string) string {
switch valueType {
case "int":
return "Integer[]"
case "long":
return "Long[]"
case "float":
return "Float[]"
case "double":
return "Double[]"
case "boolean":
return "Boolean[]"
default:
return ""
}
}
func primitiveType(tp string) (string, bool) { func primitiveType(tp string) (string, bool) {
switch tp { switch tp {
case "string": case "string":

View File

@@ -33,7 +33,7 @@ func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
for _, each := range ctx.AllSpec() { for _, each := range ctx.AllSpec() {
root := each.Accept(v).(*Api) root := each.Accept(v).(*Api)
v.acceptSyntax(root, &final) v.acceptSyntax(root, &final)
v.accetpImport(root, &final) v.acceptImport(root, &final)
v.acceptInfo(root, &final) v.acceptInfo(root, &final)
v.acceptType(root, &final) v.acceptType(root, &final)
v.acceptService(root, &final) v.acceptService(root, &final)
@@ -133,7 +133,7 @@ func (v *ApiVisitor) acceptInfo(root *Api, final *Api) {
} }
} }
func (v *ApiVisitor) accetpImport(root *Api, final *Api) { func (v *ApiVisitor) acceptImport(root *Api, final *Api) {
for _, imp := range root.Import { for _, imp := range root.Import {
if _, ok := final.importM[imp.Value.Text()]; ok { if _, ok := final.importM[imp.Value.Text()]; ok {
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text())) v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))

View File

@@ -50,7 +50,7 @@ type AtDoc struct {
Kv []*KvExpr Kv []*KvExpr
} }
// AtHandler describes service hander ast for api syntax // AtHandler describes service handler ast for api syntax
type AtHandler struct { type AtHandler struct {
AtHandlerToken Expr AtHandlerToken Expr
Name Expr Name Expr
@@ -630,7 +630,7 @@ func (s *Service) Equal(v interface{}) bool {
return s.ServiceApi.Equal(service.ServiceApi) return s.ServiceApi.Equal(service.ServiceApi)
} }
// Get returns the tergate KV by specified key // Get returns the target KV by specified key
func (kv KV) Get(key string) Expr { func (kv KV) Get(key string) Expr {
for _, each := range kv { for _, each := range kv {
if each.Key.Text() == key { if each.Key.Text() == key {

View File

@@ -17,7 +17,7 @@ type (
NameExpr() Expr NameExpr() Expr
} }
// TypeAlias describes alias ast for api syatax // TypeAlias describes alias ast for api syntax
TypeAlias struct { TypeAlias struct {
Name Expr Name Expr
Assign Expr Assign Expr
@@ -26,7 +26,7 @@ type (
CommentExpr Expr CommentExpr Expr
} }
// TypeStruct describes structure ast for api syatax // TypeStruct describes structure ast for api syntax
TypeStruct struct { TypeStruct struct {
Name Expr Name Expr
Struct Expr Struct Expr
@@ -225,7 +225,7 @@ func (v *ApiVisitor) VisitTypeBlockAlias(ctx *api.TypeBlockAliasContext) interfa
alias.DocExpr = v.getDoc(ctx) alias.DocExpr = v.getDoc(ctx)
alias.CommentExpr = v.getComment(ctx) alias.CommentExpr = v.getComment(ctx)
// todo: reopen if necessary // todo: reopen if necessary
v.panic(alias.Name, "unsupport alias") v.panic(alias.Name, "unsupported alias")
return &alias return &alias
} }
@@ -238,7 +238,7 @@ func (v *ApiVisitor) VisitTypeAlias(ctx *api.TypeAliasContext) interface{} {
alias.DocExpr = v.getDoc(ctx) alias.DocExpr = v.getDoc(ctx)
alias.CommentExpr = v.getComment(ctx) alias.CommentExpr = v.getComment(ctx)
// todo: reopen if necessary // todo: reopen if necessary
v.panic(alias.Name, "unsupport alias") v.panic(alias.Name, "unsupported alias")
return &alias return &alias
} }
@@ -319,7 +319,7 @@ func (v *ApiVisitor) VisitDataType(ctx *api.DataTypeContext) interface{} {
if ctx.GetTime() != nil { if ctx.GetTime() != nil {
// todo: reopen if it is necessary // todo: reopen if it is necessary
timeExpr := v.newExprWithToken(ctx.GetTime()) timeExpr := v.newExprWithToken(ctx.GetTime())
v.panic(timeExpr, "unsupport time.Time") v.panic(timeExpr, "unsupported time.Time")
return &Time{Literal: timeExpr} return &Time{Literal: timeExpr}
} }
if ctx.PointerType() != nil { if ctx.PointerType() != nil {

View File

@@ -219,7 +219,7 @@ func (p parser) fillService() error {
for _, astRoute := range item.ServiceApi.ServiceRoute { for _, astRoute := range item.ServiceApi.ServiceRoute {
route := spec.Route{ route := spec.Route{
Annotation: spec.Annotation{}, AtServerAnnotation: spec.Annotation{},
Method: astRoute.Route.Method.Text(), Method: astRoute.Route.Method.Text(),
Path: astRoute.Route.Path.Text(), Path: astRoute.Route.Path.Text(),
} }
@@ -275,7 +275,7 @@ func (p parser) fillRouteAtServer(astRoute *ast.ServiceRoute, route *spec.Route)
for _, kv := range astRoute.AtServer.Kv { for _, kv := range astRoute.AtServer.Kv {
properties[kv.Key.Text()] = kv.Value.Text() properties[kv.Key.Text()] = kv.Value.Text()
} }
route.Annotation.Properties = properties route.AtServerAnnotation.Properties = properties
if len(route.Handler) == 0 { if len(route.Handler) == 0 {
route.Handler = properties["handler"] route.Handler = properties["handler"]
} }

View File

@@ -11,10 +11,11 @@ import (
const ( const (
bodyTagKey = "json" bodyTagKey = "json"
formTagKey = "form" formTagKey = "form"
pathTagKey = "path"
defaultSummaryKey = "summary" defaultSummaryKey = "summary"
) )
var definedKeys = []string{bodyTagKey, formTagKey, "path"} var definedKeys = []string{bodyTagKey, formTagKey, pathTagKey}
// Routes returns all routes in api service // Routes returns all routes in api service
func (s Service) Routes() []Route { func (s Service) Routes() []Route {
@@ -25,7 +26,7 @@ func (s Service) Routes() []Route {
return result return result
} }
// Tags retuens all tags in Member // Tags returns all tags in Member
func (m Member) Tags() []*Tag { func (m Member) Tags() []*Tag {
tags, err := Parse(m.Tag) tags, err := Parse(m.Tag)
if err != nil { if err != nil {
@@ -141,7 +142,7 @@ func (t DefineStruct) GetFormMembers() []Member {
return result return result
} }
// GetNonBodyMembers retruns all have no tag fields // GetNonBodyMembers returns all have no tag fields
func (t DefineStruct) GetNonBodyMembers() []Member { func (t DefineStruct) GetNonBodyMembers() []Member {
var result []Member var result []Member
for _, member := range t.Members { for _, member := range t.Members {
@@ -162,16 +163,16 @@ func (r Route) JoinedDoc() string {
return strings.TrimSpace(doc) return strings.TrimSpace(doc)
} }
// GetAnnotation returns the value by specified key // GetAnnotation returns the value by specified key from @server
func (r Route) GetAnnotation(key string) string { func (r Route) GetAnnotation(key string) string {
if r.Annotation.Properties == nil { if r.AtServerAnnotation.Properties == nil {
return "" return ""
} }
return r.Annotation.Properties[key] return r.AtServerAnnotation.Properties[key]
} }
// GetAnnotation returns the value by specified key // GetAnnotation returns the value by specified key from @server
func (g Group) GetAnnotation(key string) string { func (g Group) GetAnnotation(key string) string {
if g.Annotation.Properties == nil { if g.Annotation.Properties == nil {
return "" return ""

View File

@@ -63,7 +63,7 @@ type (
// Route describes api route // Route describes api route
Route struct { Route struct {
Annotation Annotation AtServerAnnotation Annotation
Method string Method string
Path string Path string
RequestType Type RequestType Type

View File

@@ -11,7 +11,7 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
) )
// TsCommand provides the entry to generting typescript codes // TsCommand provides the entry to generate typescript codes
func TsCommand(c *cli.Context) error { func TsCommand(c *cli.Context) error {
apiFile := c.String("api") apiFile := c.String("api")
dir := c.String("dir") dir := c.String("dir")

View File

@@ -19,6 +19,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/configgen" "github.com/tal-tech/go-zero/tools/goctl/configgen"
"github.com/tal-tech/go-zero/tools/goctl/docker" "github.com/tal-tech/go-zero/tools/goctl/docker"
"github.com/tal-tech/go-zero/tools/goctl/kube" "github.com/tal-tech/go-zero/tools/goctl/kube"
"github.com/tal-tech/go-zero/tools/goctl/model/mongo"
model "github.com/tal-tech/go-zero/tools/goctl/model/sql/command" model "github.com/tal-tech/go-zero/tools/goctl/model/sql/command"
"github.com/tal-tech/go-zero/tools/goctl/plugin" "github.com/tal-tech/go-zero/tools/goctl/plugin"
rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/cli" rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/cli"
@@ -28,7 +29,7 @@ import (
) )
var ( var (
buildVersion = "1.1.5" buildVersion = "1.1.6"
commands = []cli.Command{ commands = []cli.Command{
{ {
Name: "upgrade", Name: "upgrade",
@@ -447,6 +448,29 @@ var (
}, },
}, },
}, },
{
Name: "mongo",
Usage: `generate mongo model`,
Flags: []cli.Flag{
cli.StringSliceFlag{
Name: "type, t",
Usage: "specified model type name",
},
cli.BoolFlag{
Name: "cache, c",
Usage: "generate code with cache [optional]",
},
cli.StringFlag{
Name: "dir, d",
Usage: "the target dir",
},
cli.StringFlag{
Name: "style",
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
},
},
Action: mongo.Action,
},
}, },
}, },
{ {

View File

@@ -0,0 +1,69 @@
package generate
import (
"errors"
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/model/mongo/template"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
)
// Context defines the model generation data what they needs
type Context struct {
Types []string
Cache bool
Output string
Cfg *config.Config
}
// Do executes model template and output the result into the specified file path
func Do(ctx *Context) error {
if ctx.Cfg == nil {
return errors.New("missing config")
}
err := generateModel(ctx)
if err != nil {
return err
}
return generateError(ctx)
}
func generateModel(ctx *Context) error {
for _, t := range ctx.Types {
fn, err := format.FileNamingFormat(ctx.Cfg.NamingFormat, t+"_model")
if err != nil {
return err
}
text, err := util.LoadTemplate(category, modelTemplateFile, template.Text)
if err != nil {
return err
}
output := filepath.Join(ctx.Output, fn+".go")
err = util.With("model").Parse(text).GoFmt(true).SaveTo(map[string]interface{}{
"Type": t,
"Cache": ctx.Cache,
}, output, false)
if err != nil {
return err
}
}
return nil
}
func generateError(ctx *Context) error {
text, err := util.LoadTemplate(category, errTemplateFile, template.Error)
if err != nil {
return err
}
output := filepath.Join(ctx.Output, "error.go")
return util.With("error").Parse(text).GoFmt(true).SaveTo(ctx, output, false)
}

View File

@@ -0,0 +1,34 @@
package generate
import (
"io/ioutil"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/config"
)
var testTypes = `
type User struct{}
type Class struct{}
`
func TestDo(t *testing.T) {
cfg, err := config.NewConfig(config.DefaultFormat)
assert.Nil(t, err)
tempDir := t.TempDir()
typesfile := filepath.Join(tempDir, "types.go")
err = ioutil.WriteFile(typesfile, []byte(testTypes), 0666)
assert.Nil(t, err)
err = Do(&Context{
Types: []string{"User", "Class"},
Cache: false,
Output: tempDir,
Cfg: cfg,
})
assert.Nil(t, err)
}

View File

@@ -0,0 +1,55 @@
package generate
import (
"fmt"
"github.com/tal-tech/go-zero/tools/goctl/model/mongo/template"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli"
)
const (
category = "mongo"
modelTemplateFile = "model.tpl"
errTemplateFile = "err.tpl"
)
var templates = map[string]string{
modelTemplateFile: template.Text,
errTemplateFile: template.Error,
}
// Category returns the mongo category.
func Category() string {
return category
}
// Clean cleans the mongo templates.
func Clean() error {
return util.Clean(category)
}
// Templates initializes the mongo templates.
func Templates(_ *cli.Context) error {
return util.InitTemplates(category, templates)
}
// RevertTemplate reverts the given template.
func RevertTemplate(name string) error {
content, ok := templates[name]
if !ok {
return fmt.Errorf("%s: no such file name", name)
}
return util.CreateTemplate(category, name, content)
}
// Update cleans and updates the templates.
func Update() error {
err := Clean()
if err != nil {
return err
}
return util.InitTemplates(category, templates)
}

View File

@@ -0,0 +1,39 @@
package mongo
import (
"errors"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/model/mongo/generate"
"github.com/urfave/cli"
)
// Action provides the entry for goctl mongo code generation.
func Action(ctx *cli.Context) error {
tp := ctx.StringSlice("type")
c := ctx.Bool("cache")
o := strings.TrimSpace(ctx.String("dir"))
s := ctx.String("style")
if len(tp) == 0 {
return errors.New("missing type")
}
cfg, err := config.NewConfig(s)
if err != nil {
return err
}
a, err := filepath.Abs(o)
if err != nil {
return err
}
return generate.Do(&generate.Context{
Types: tp,
Cache: c,
Output: a,
Cfg: cfg,
})
}

View File

@@ -0,0 +1,210 @@
# mongo生成model
## 背景
在业务务开发中model(dao)数据访问层是一个服务必不可缺的一层因此数据库访问的CURD也是必须要对外提供的访问方法 而CURD在go-zero中就仅存在两种情况
* 带缓存model
* 不带缓存model
从代码结构上来看C-U-R-D四个方法就是固定的结构因此我们可以将其交给goctl工具去完成帮助我们提升开发效率。
## 方案设计
mongo的生成不同于mysqlmysql可以从scheme_information库中读取到一张表的信息字段名称数据类型索引等
而mongo是文档型数据库我们暂时无法从db中读取某一条记录来实现字段信息获取就算有也不一定是完整信息某些字段可能是omitempty修饰可有可无 这里采用type自己编写+代码生成方式实现
## 使用示例
假设我们需要生成一个usermodel.go的代码文件其包含用户信息字段有
|字段名称|字段类型|
|---|---|
|_id|bson.ObejctId|
|name|string|
### 编写types.go
```shell
$ vim types.go
```
```golang
package model
//go:generate goctl model mongo -t User
import "github.com/globalsign/mgo/bson"
type User struct {
ID bson.ObjectId `bson:"_id"`
Name string `bson:"name"`
}
```
### 生成代码
生成代码的方式有两种
* 命令行生成 在types.go所在文件夹执行命令
```shell
$ goctl model mongo -t User -style gozero
```
* 在types.go中添加`//go:generate`,然后点击执行按钮即可生成,内容示例如下:
```golang
//go:generate goctl model mongo -t User
```
### 生成示例代码
* usermodel.go
```golang
package model
import (
"context"
"github.com/globalsign/mgo/bson"
cachec "github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/mongoc"
)
type UserModel interface {
Insert(data *User, ctx context.Context) error
FindOne(id string, ctx context.Context) (*User, error)
Update(data *User, ctx context.Context) error
Delete(id string, ctx context.Context) error
}
type defaultUserModel struct {
*mongoc.Model
}
func NewUserModel(url, collection string, c cachec.CacheConf) UserModel {
return &defaultUserModel{
Model: mongoc.MustNewModel(url, collection, c),
}
}
func (m *defaultUserModel) Insert(data *User, ctx context.Context) error {
if !data.ID.Valid() {
data.ID = bson.NewObjectId()
}
session, err := m.TakeSession()
if err != nil {
return err
}
defer m.PutSession(session)
return m.GetCollection(session).Insert(data)
}
func (m *defaultUserModel) FindOne(id string, ctx context.Context) (*User, error) {
if !bson.IsObjectIdHex(id) {
return nil, ErrInvalidObjectId
}
session, err := m.TakeSession()
if err != nil {
return nil, err
}
defer m.PutSession(session)
var data User
err = m.GetCollection(session).FindOneIdNoCache(&data, bson.ObjectIdHex(id))
switch err {
case nil:
return &data, nil
case mongoc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *defaultUserModel) Update(data *User, ctx context.Context) error {
session, err := m.TakeSession()
if err != nil {
return err
}
defer m.PutSession(session)
return m.GetCollection(session).UpdateIdNoCache(data.ID, data)
}
func (m *defaultUserModel) Delete(id string, ctx context.Context) error {
session, err := m.TakeSession()
if err != nil {
return err
}
defer m.PutSession(session)
return m.GetCollection(session).RemoveIdNoCache(bson.ObjectIdHex(id))
}
```
* error.go
```golang
package model
import "errors"
var ErrNotFound = errors.New("not found")
var ErrInvalidObjectId = errors.New("invalid objectId")
```
### 文件目录预览
```text
.
├── error.go
├── types.go
└── usermodel.go
```
## 命令预览
```text
NAME:
goctl model - generate model code
USAGE:
goctl model command [command options] [arguments...]
COMMANDS:
mysql generate mysql model
mongo generate mongo model
OPTIONS:
--help, -h show help
```
```text
NAME:
goctl model mongo - generate mongo model
USAGE:
goctl model mongo [command options] [arguments...]
OPTIONS:
--type value, -t value specified model type name
--cache, -c generate code with cache [optional]
--dir value, -d value the target dir
--style value the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]
```
> 温馨提示
>
> `--type` 支持slice传值示例 `goctl model mongo -t=User -t=Class`
## 注意事项
types.go本质上与xxxmodel.go无关只是将type定义部分交给开发人员自己编写了在xxxmodel.go中mongo文档的存储结构必须包含
`_id`字段对应到types中的field为`ID`model中的findOne,update均以data.ID来进行操作的当然如果不符合你的命名风格你也 可以修改模板,只要保证`id`
在types中的field名称和模板中一致就行。

View File

@@ -0,0 +1,112 @@
package template
// Text provides the default template for model to generate
var Text = `package model
import (
"context"
"github.com/globalsign/mgo/bson"
cachec "github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/mongoc"
)
{{if .Cache}}var prefix{{.Type}}CacheKey = "cache#{{.Type}}#"{{end}}
type {{.Type}}Model interface{
Insert(ctx context.Context,data *{{.Type}}) error
FindOne(ctx context.Context,id string) (*{{.Type}}, error)
Update(ctx context.Context,data *{{.Type}}) error
Delete(ctx context.Context,id string) error
}
type default{{.Type}}Model struct {
*mongoc.Model
}
func New{{.Type}}Model(url, collection string, c cachec.CacheConf) {{.Type}}Model {
return &default{{.Type}}Model{
Model: mongoc.MustNewModel(url, collection, c),
}
}
func (m *default{{.Type}}Model) Insert(ctx context.Context, data *{{.Type}}) error {
if !data.ID.Valid() {
data.ID = bson.NewObjectId()
}
session, err := m.TakeSession()
if err != nil {
return err
}
defer m.PutSession(session)
return m.GetCollection(session).Insert(data)
}
func (m *default{{.Type}}Model) FindOne(ctx context.Context, id string) (*{{.Type}}, error) {
if !bson.IsObjectIdHex(id) {
return nil, ErrInvalidObjectId
}
session, err := m.TakeSession()
if err != nil {
return nil, err
}
defer m.PutSession(session)
var data {{.Type}}
{{if .Cache}}key := prefix{{.Type}}CacheKey + id
err = m.GetCollection(session).FindOneId(&data, key, bson.ObjectIdHex(id))
{{- else}}
err = m.GetCollection(session).FindOneIdNoCache(&data, bson.ObjectIdHex(id))
{{- end}}
switch err {
case nil:
return &data,nil
case mongoc.ErrNotFound:
return nil,ErrNotFound
default:
return nil,err
}
}
func (m *default{{.Type}}Model) Update(ctx context.Context, data *{{.Type}}) error {
session, err := m.TakeSession()
if err != nil {
return err
}
defer m.PutSession(session)
{{if .Cache}}key := prefix{{.Type}}CacheKey + data.ID.Hex()
return m.GetCollection(session).UpdateId(data.ID, data, key)
{{- else}}
return m.GetCollection(session).UpdateIdNoCache(data.ID, data)
{{- end}}
}
func (m *default{{.Type}}Model) Delete(ctx context.Context, id string) error {
session, err := m.TakeSession()
if err != nil {
return err
}
defer m.PutSession(session)
{{if .Cache}}key := prefix{{.Type}}CacheKey + id
return m.GetCollection(session).RemoveId(bson.ObjectIdHex(id), key)
{{- else}}
return m.GetCollection(session).RemoveIdNoCache(bson.ObjectIdHex(id))
{{- end}}
}
`
// Error provides the default template for error definition in mongo code generation.
var Error = `
package model
import "errors"
var ErrNotFound = errors.New("not found")
var ErrInvalidObjectId = errors.New("invalid objectId")
`

View File

@@ -6,6 +6,8 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/config" "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
@@ -19,7 +21,10 @@ var (
) )
func TestFromDDl(t *testing.T) { func TestFromDDl(t *testing.T) {
err := fromDDl("./user.sql", t.TempDir(), cfg, true, false) err := gen.Clean()
assert.Nil(t, err)
err = fromDDl("./user.sql", t.TempDir(), cfg, true, false)
assert.Equal(t, errNotMatched, err) assert.Equal(t, errNotMatched, err)
// case dir is not exists // case dir is not exists

View File

@@ -25,27 +25,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
var list []string var list []string
camelTableName := table.Name.ToCamel() camelTableName := table.Name.ToCamel()
for _, key := range table.UniqueCacheKey { for _, key := range table.UniqueCacheKey {
var inJoin, paramJoin, argJoin Join in, paramJoinString, originalFieldString := convertJoin(key)
for _, f := range key.Fields {
param := stringx.From(f.Name.ToCamel()).Untitle()
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
paramJoin = append(paramJoin, param)
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
}
var in string
if len(inJoin) > 0 {
in = inJoin.With(", ").Source()
}
var paramJoinString string
if len(paramJoin) > 0 {
paramJoinString = paramJoin.With(",").Source()
}
var originalFieldString string
if len(argJoin) > 0 {
originalFieldString = argJoin.With(" and ").Source()
}
output, err := t.Execute(map[string]interface{}{ output, err := t.Execute(map[string]interface{}{
"upperStartCamelObject": camelTableName, "upperStartCamelObject": camelTableName,
@@ -125,3 +105,25 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
findOneInterfaceMethod: strings.Join(listMethod, util.NL), findOneInterfaceMethod: strings.Join(listMethod, util.NL),
}, nil }, nil
} }
func convertJoin(key Key) (in, paramJoinString, originalFieldString string) {
var inJoin, paramJoin, argJoin Join
for _, f := range key.Fields {
param := stringx.From(f.Name.ToCamel()).Untitle()
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
paramJoin = append(paramJoin, param)
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
}
if len(inJoin) > 0 {
in = inJoin.With(", ").Source()
}
if len(paramJoin) > 0 {
paramJoinString = paramJoin.With(",").Source()
}
if len(argJoin) > 0 {
originalFieldString = argJoin.With(" and ").Source()
}
return in, paramJoinString, originalFieldString
}

View File

@@ -11,15 +11,15 @@ import (
// Key describes cache key // Key describes cache key
type Key struct { type Key struct {
// VarLeft describes the varible of cache key expression which likes cacheUserIdPrefix // VarLeft describes the variable of cache key expression which likes cacheUserIdPrefix
VarLeft string VarLeft string
// VarRight describes the value of cache key expression which likes "cache#user#id#" // VarRight describes the value of cache key expression which likes "cache#user#id#"
VarRight string VarRight string
// VarExpression describes the cache key expression which likes cacheUserIdPrefix = "cache#user#id#" // VarExpression describes the cache key expression which likes cacheUserIdPrefix = "cache#user#id#"
VarExpression string VarExpression string
// KeyLeft describes the varible of key definiation expression which likes userKey // KeyLeft describes the variable of key definition expression which likes userKey
KeyLeft string KeyLeft string
// KeyRight describes the value of key definiation expression which likes fmt.Sprintf("%s%v", cacheUserPrefix, user) // KeyRight describes the value of key definition expression which likes fmt.Sprintf("%s%v", cacheUserPrefix, user)
KeyRight string KeyRight string
// DataKeyRight describes data key likes fmt.Sprintf("%s%v", cacheUserPrefix, data.User) // DataKeyRight describes data key likes fmt.Sprintf("%s%v", cacheUserPrefix, data.User)
DataKeyRight string DataKeyRight string

View File

@@ -3,6 +3,7 @@ package gen
import ( import (
"strings" "strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
@@ -23,6 +24,15 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
expressionValues = append(expressionValues, "data."+camel) expressionValues = append(expressionValues, "data."+camel)
} }
keySet := collection.NewSet()
keyVariableSet := collection.NewSet()
keySet.AddStr(table.PrimaryCacheKey.DataKeyExpression)
keyVariableSet.AddStr(table.PrimaryCacheKey.KeyLeft)
for _, key := range table.UniqueCacheKey {
keySet.AddStr(key.DataKeyExpression)
keyVariableSet.AddStr(key.KeyLeft)
}
expressionValues = append(expressionValues, "data."+table.PrimaryKey.Name.ToCamel()) expressionValues = append(expressionValues, "data."+table.PrimaryKey.Name.ToCamel())
camelTableName := table.Name.ToCamel() camelTableName := table.Name.ToCamel()
text, err := util.LoadTemplate(category, updateTemplateFile, template.Update) text, err := util.LoadTemplate(category, updateTemplateFile, template.Update)
@@ -35,6 +45,8 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
Execute(map[string]interface{}{ Execute(map[string]interface{}{
"withCache": withCache, "withCache": withCache,
"upperStartCamelObject": camelTableName, "upperStartCamelObject": camelTableName,
"keys": strings.Join(keySet.KeysStr(), "\n"),
"keyValues": strings.Join(keyVariableSet.KeysStr(), ", "),
"primaryCacheKey": table.PrimaryCacheKey.DataKeyExpression, "primaryCacheKey": table.PrimaryCacheKey.DataKeyExpression,
"primaryKeyVariable": table.PrimaryCacheKey.KeyLeft, "primaryKeyVariable": table.PrimaryCacheKey.KeyLeft,
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(), "lowerStartCamelObject": stringx.From(camelTableName).Untitle(),

View File

@@ -102,6 +102,17 @@ func Parse(ddl string) (*Table, error) {
} }
} }
checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
UniqueIndex: uniqueIndex,
NormalIndex: normalIndex,
Fields: fields,
}, nil
}
func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
log := console.NewColorConsole() log := console.NewColorConsole()
uniqueSet := collection.NewSet() uniqueSet := collection.NewSet()
for k, i := range uniqueIndex { for k, i := range uniqueIndex {
@@ -136,14 +147,6 @@ func Parse(ddl string) (*Table, error) {
normalIndexSet.Add(joinRet) normalIndexSet.Add(joinRet)
} }
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
UniqueIndex: uniqueIndex,
NormalIndex: normalIndex,
Fields: fields,
}, nil
} }
func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) { func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
@@ -289,28 +292,10 @@ func ConvertDataType(table *model.Table) (*Table, error) {
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"), AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
} }
fieldM := make(map[string]*Field) fieldM, err := getTableFields(table)
for _, each := range table.Columns {
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
if err != nil { if err != nil {
return nil, err return nil, err
} }
columnSeqInIndex := 0
if each.Index != nil {
columnSeqInIndex = each.Index.SeqInIndex
}
field := &Field{
Name: stringx.From(each.Name),
DataBaseType: each.DataType,
DataType: dt,
Comment: each.Comment,
SeqInIndex: columnSeqInIndex,
OrdinalPosition: each.OrdinalPosition,
}
fieldM[each.Name] = field
}
for _, each := range fieldM { for _, each := range fieldM {
reply.Fields = append(reply.Fields, each) reply.Fields = append(reply.Fields, each)
@@ -379,3 +364,29 @@ func ConvertDataType(table *model.Table) (*Table, error) {
return &reply, nil return &reply, nil
} }
func getTableFields(table *model.Table) (map[string]*Field, error) {
fieldM := make(map[string]*Field)
for _, each := range table.Columns {
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
if err != nil {
return nil, err
}
columnSeqInIndex := 0
if each.Index != nil {
columnSeqInIndex = each.Index.SeqInIndex
}
field := &Field{
Name: stringx.From(each.Name),
DataBaseType: each.DataType,
DataType: dt,
Comment: each.Comment,
SeqInIndex: columnSeqInIndex,
OrdinalPosition: each.OrdinalPosition,
}
fieldM[each.Name] = field
}
return fieldM, nil
}

View File

@@ -3,11 +3,11 @@ package template
// Update defines a template for generating update codes // Update defines a template for generating update codes
var Update = ` var Update = `
func (m *default{{.upperStartCamelObject}}Model) Update(data {{.upperStartCamelObject}}) error { func (m *default{{.upperStartCamelObject}}Model) Update(data {{.upperStartCamelObject}}) error {
{{if .withCache}}{{.primaryCacheKey}} {{if .withCache}}{{.keys}}
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder) query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder)
return conn.Exec(query, {{.expressionValues}}) return conn.Exec(query, {{.expressionValues}})
}, {{.primaryKeyVariable}}){{else}}query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder) }, {{.keyValues}}){{else}}query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder)
_,err:=m.conn.Exec(query, {{.expressionValues}}){{end}} _,err:=m.conn.Exec(query, {{.expressionValues}}){{end}}
return err return err
} }

View File

@@ -8,6 +8,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/gogen" "github.com/tal-tech/go-zero/tools/goctl/api/gogen"
"github.com/tal-tech/go-zero/tools/goctl/docker" "github.com/tal-tech/go-zero/tools/goctl/docker"
"github.com/tal-tech/go-zero/tools/goctl/kube" "github.com/tal-tech/go-zero/tools/goctl/kube"
mongogen "github.com/tal-tech/go-zero/tools/goctl/model/mongo/generate"
modelgen "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen" modelgen "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/generator" rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
@@ -16,7 +17,7 @@ import (
const templateParentPath = "/" const templateParentPath = "/"
// GenTemplates wtites the latest template text into file which is not exists // GenTemplates writes the latest template text into file which is not exists
func GenTemplates(ctx *cli.Context) error { func GenTemplates(ctx *cli.Context) error {
if err := errorx.Chain( if err := errorx.Chain(
func() error { func() error {
@@ -34,6 +35,9 @@ func GenTemplates(ctx *cli.Context) error {
func() error { func() error {
return kube.GenTemplates(ctx) return kube.GenTemplates(ctx)
}, },
func() error {
return mongogen.Templates(ctx)
},
); err != nil { ); err != nil {
return err return err
} }
@@ -61,6 +65,15 @@ func CleanTemplates(_ *cli.Context) error {
func() error { func() error {
return rpcgen.Clean() return rpcgen.Clean()
}, },
func() error {
return docker.Clean()
},
func() error {
return kube.Clean()
},
func() error {
return mongogen.Clean()
},
) )
if err != nil { if err != nil {
return err return err
@@ -70,8 +83,8 @@ func CleanTemplates(_ *cli.Context) error {
return nil return nil
} }
// UpdateTemplates wtites the latest template text into file, // UpdateTemplates writes the latest template text into file,
// it will delete the oldler templates if there are exists // it will delete the older templates if there are exists
func UpdateTemplates(ctx *cli.Context) (err error) { func UpdateTemplates(ctx *cli.Context) (err error) {
category := ctx.String("category") category := ctx.String("category")
defer func() { defer func() {
@@ -90,6 +103,8 @@ func UpdateTemplates(ctx *cli.Context) (err error) {
return rpcgen.Update() return rpcgen.Update()
case modelgen.Category(): case modelgen.Category():
return modelgen.Update() return modelgen.Update()
case mongogen.Category():
return mongogen.Update()
default: default:
err = fmt.Errorf("unexpected category: %s", category) err = fmt.Errorf("unexpected category: %s", category)
return return
@@ -116,6 +131,8 @@ func RevertTemplates(ctx *cli.Context) (err error) {
return rpcgen.RevertTemplate(filename) return rpcgen.RevertTemplate(filename)
case modelgen.Category(): case modelgen.Category():
return modelgen.RevertTemplate(filename) return modelgen.RevertTemplate(filename)
case mongogen.Category():
return mongogen.RevertTemplate(filename)
default: default:
err = fmt.Errorf("unexpected category: %s", category) err = fmt.Errorf("unexpected category: %s", category)
return return

View File

@@ -9,7 +9,7 @@ import (
type ( type (
// Console wraps from the fmt.Sprintf, // Console wraps from the fmt.Sprintf,
// by default, it implemented the colorConsole to provide the colorful output to the consle // by default, it implemented the colorConsole to provide the colorful output to the console
// and the ideaConsole to output with prefix for the plugin of intellij // and the ideaConsole to output with prefix for the plugin of intellij
Console interface { Console interface {
Success(format string, a ...interface{}) Success(format string, a ...interface{})
@@ -81,7 +81,7 @@ func (c *colorConsole) Must(err error) {
} }
} }
// NewIdeaConsole returns a instace of ideaConsole // NewIdeaConsole returns a instance of ideaConsole
func NewIdeaConsole() Console { func NewIdeaConsole() Console {
return &ideaConsole{} return &ideaConsole{}
} }

View File

@@ -27,7 +27,7 @@ func CreateIfNotExist(file string) (*os.File, error) {
return os.Create(file) return os.Create(file)
} }
// RemoveIfExist deletes the specficed file if it is exists // RemoveIfExist deletes the specified file if it is exists
func RemoveIfExist(filename string) error { func RemoveIfExist(filename string) error {
if !FileExists(filename) { if !FileExists(filename) {
return nil return nil
@@ -36,7 +36,7 @@ func RemoveIfExist(filename string) error {
return os.Remove(filename) return os.Remove(filename)
} }
// RemoveOrQuit deletes the specficed file if read a permit command from stdin // RemoveOrQuit deletes the specified file if read a permit command from stdin
func RemoveOrQuit(filename string) error { func RemoveOrQuit(filename string) error {
if !FileExists(filename) { if !FileExists(filename) {
return nil return nil
@@ -49,7 +49,7 @@ func RemoveOrQuit(filename string) error {
return os.Remove(filename) return os.Remove(filename)
} }
// FileExists returns true if the specficed file is exists // FileExists returns true if the specified file is exists
func FileExists(file string) bool { func FileExists(file string) bool {
_, err := os.Stat(file) _, err := os.Stat(file)
return err == nil return err == nil

View File

@@ -18,7 +18,7 @@ const (
upper upper
) )
// ErrNamingFormat defines an error for unknown fomat // ErrNamingFormat defines an error for unknown format
var ErrNamingFormat = errors.New("unsupported format") var ErrNamingFormat = errors.New("unsupported format")
type ( type (

View File

@@ -33,7 +33,7 @@ func MkdirIfNotExist(dir string) error {
return nil return nil
} }
// PathFromGoSrc returns the path whihout slash where has been trim the prefix $GOPATH // PathFromGoSrc returns the path without slash where has been trim the prefix $GOPATH
func PathFromGoSrc() (string, error) { func PathFromGoSrc() (string, error) {
dir, err := os.Getwd() dir, err := os.Getwd()
if err != nil { if err != nil {

View File

@@ -6,7 +6,7 @@ import (
"unicode" "unicode"
) )
// String provides for coverting the source text into other spell case,like lower,snake,camel // String provides for converting the source text into other spell case,like lower,snake,camel
type String struct { type String struct {
source string source string
} }

View File

@@ -17,7 +17,7 @@ type DefaultTemplate struct {
savePath string savePath string
} }
// With returns a instace of DefaultTemplate // With returns a instance of DefaultTemplate
func With(name string) *DefaultTemplate { func With(name string) *DefaultTemplate {
return &DefaultTemplate{ return &DefaultTemplate{
name: name, name: name,
@@ -30,7 +30,7 @@ func (t *DefaultTemplate) Parse(text string) *DefaultTemplate {
return t return t
} }
// GoFmt sets the value to goFmt and marks the generated codes will be formated or not // GoFmt sets the value to goFmt and marks the generated codes will be formatted or not
func (t *DefaultTemplate) GoFmt(format bool) *DefaultTemplate { func (t *DefaultTemplate) GoFmt(format bool) *DefaultTemplate {
t.goFmt = format t.goFmt = format
return t return t

View File

@@ -3,7 +3,7 @@ package vars
const ( const (
// ProjectName the const value of zero // ProjectName the const value of zero
ProjectName = "zero" ProjectName = "zero"
// ProjectOpenSourceURL the githb url of go-zero // ProjectOpenSourceURL the github url of go-zero
ProjectOpenSourceURL = "github.com/tal-tech/go-zero" ProjectOpenSourceURL = "github.com/tal-tech/go-zero"
// OsWindows windows os // OsWindows windows os
OsWindows = "windows" OsWindows = "windows"

View File

@@ -18,6 +18,27 @@ func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
ctx, cancel := contextx.ShrinkDeadline(ctx, timeout) ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
defer cancel() defer cancel()
return invoker(ctx, method, req, reply, cc, opts...)
// create channel with buffer size 1 to avoid goroutine leak
done := make(chan error, 1)
panicChan := make(chan interface{}, 1)
go func() {
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()
done <- invoker(ctx, method, req, reply, cc, opts...)
}()
select {
case p := <-panicChan:
panic(p)
case err := <-done:
return err
case <-ctx.Done():
return ctx.Err()
}
} }
} }

View File

@@ -48,3 +48,40 @@ func TestTimeoutInterceptor_timeout(t *testing.T) {
wg.Wait() wg.Wait()
assert.Nil(t, err) assert.Nil(t, err)
} }
func TestTimeoutInterceptor_timeoutExpire(t *testing.T) {
const timeout = time.Millisecond * 10
interceptor := TimeoutInterceptor(timeout)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
var wg sync.WaitGroup
wg.Add(1)
cc := new(grpc.ClientConn)
err := interceptor(ctx, "/foo", nil, nil, cc,
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
defer wg.Done()
time.Sleep(time.Millisecond * 50)
return nil
})
wg.Wait()
assert.Equal(t, context.DeadlineExceeded, err)
}
func TestTimeoutInterceptor_panic(t *testing.T) {
timeouts := []time.Duration{0, time.Millisecond * 10}
for _, timeout := range timeouts {
t.Run(strconv.FormatInt(int64(timeout), 10), func(t *testing.T) {
interceptor := TimeoutInterceptor(timeout)
cc := new(grpc.ClientConn)
assert.Panics(t, func() {
_ = interceptor(context.Background(), "/foo", nil, nil, cc,
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
panic("any")
},
)
})
})
}
}

View File

@@ -2,6 +2,7 @@ package serverinterceptors
import ( import (
"context" "context"
"sync"
"time" "time"
"github.com/tal-tech/go-zero/core/contextx" "github.com/tal-tech/go-zero/core/contextx"
@@ -11,9 +12,38 @@ import (
// UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests. // UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor { func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (resp interface{}, err error) { handler grpc.UnaryHandler) (interface{}, error) {
ctx, cancel := contextx.ShrinkDeadline(ctx, timeout) ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
defer cancel() defer cancel()
return handler(ctx, req)
var resp interface{}
var err error
var lock sync.Mutex
done := make(chan struct{})
// create channel with buffer size 1 to avoid goroutine leak
panicChan := make(chan interface{}, 1)
go func() {
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()
lock.Lock()
defer lock.Unlock()
resp, err = handler(ctx, req)
close(done)
}()
select {
case p := <-panicChan:
panic(p)
case <-done:
lock.Lock()
defer lock.Unlock()
return resp, err
case <-ctx.Done():
return nil, ctx.Err()
}
} }
} }

View File

@@ -20,6 +20,17 @@ func TestUnaryTimeoutInterceptor(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
} }
func TestUnaryTimeoutInterceptor_panic(t *testing.T) {
interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
assert.Panics(t, func() {
_, _ = interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
FullMethod: "/",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
panic("any")
})
})
}
func TestUnaryTimeoutInterceptor_timeout(t *testing.T) { func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
const timeout = time.Millisecond * 10 const timeout = time.Millisecond * 10
interceptor := UnaryTimeoutInterceptor(timeout) interceptor := UnaryTimeoutInterceptor(timeout)
@@ -39,3 +50,21 @@ func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
wg.Wait() wg.Wait()
assert.Nil(t, err) assert.Nil(t, err)
} }
func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) {
const timeout = time.Millisecond * 10
interceptor := UnaryTimeoutInterceptor(timeout)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
var wg sync.WaitGroup
wg.Add(1)
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
FullMethod: "/",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
defer wg.Done()
time.Sleep(time.Millisecond * 50)
return nil, nil
})
wg.Wait()
assert.Equal(t, context.DeadlineExceeded, err)
}

View File

@@ -79,6 +79,8 @@ func (rs *RpcServer) AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterc
} }
// Start starts the RpcServer. // Start starts the RpcServer.
// Graceful shutdown is enabled by default.
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
func (rs *RpcServer) Start() { func (rs *RpcServer) Start() {
if err := rs.server.Start(rs.register); err != nil { if err := rs.server.Start(rs.register); err != nil {
logx.Error(err) logx.Error(err)