Compare commits

...

16 Commits

Author SHA1 Message Date
kevin
1fd2ef9347 make tests faster 2020-10-21 21:43:41 +08:00
kevin
efffb40fa3 update wechat info 2020-10-21 20:26:35 +08:00
kevin
9c8f31cf83 can only specify one origin in cors 2020-10-21 16:47:49 +08:00
kevin
96cb7af728 make tests faster 2020-10-21 15:18:22 +08:00
Keson
41964f9d52 gozero template (#147)
* model/rpc generate code from template cache

* delete unused(deprecated) code

* support template init|update|clean|revert

* model: return the execute result for insert and update operation

* // deprecated: containsAny

* add template test

* add default buildVersion

* update build version
2020-10-21 14:59:35 +08:00
kevin
fe0d0687f5 support cors in rest server 2020-10-21 14:10:36 +08:00
kingxt
1c1e4bca86 optimized generator formatted code (#148)
* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl support import

* goctl support import

* support return ()

* revert

* refactor and rename folder to group

* remove no need

* add anonymous annotation

* optimized

* rename

* rename

* update test

* api add middleware support: usage:

@server(
    middleware: M1, M2
)

* api add middleware support: usage:

@server(
    middleware: M1, M2
)

* simple logic

* optimized

* optimized generator formatted code

* optimized generator formatted code

* add more test

Co-authored-by: kingxt <dream4kingxt@163.com>
2020-10-20 19:43:20 +08:00
kevin
1abe21aa2a export WithUnaryClientInterceptor 2020-10-20 18:03:05 +08:00
kevin
cee170f3e9 fix zrpc client interceptor calling problem 2020-10-20 17:57:41 +08:00
kevin
907efd92c9 let balancer to be customizable 2020-10-20 17:01:53 +08:00
kevin
737cd4751a rename NewPatRouter to NewRouter 2020-10-20 14:23:21 +08:00
kevin
dfe6e88529 use goctl template to generate all kinds of templates 2020-10-19 23:13:18 +08:00
kingxt
85a815bea0 fix name typo and format with newline (#143)
* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl support import

* goctl support import

* support return ()

* revert

* refactor and rename folder to group

* remove no need

* add anonymous annotation

* optimized

* rename

* rename

* update test

* api add middleware support: usage:

@server(
    middleware: M1, M2
)

* api add middleware support: usage:

@server(
    middleware: M1, M2
)

* simple logic

* optimized

* bugs fix for name typo and format with newline

Co-authored-by: kingxt <dream4kingxt@163.com>
2020-10-19 21:05:00 +08:00
kingxt
aa3c391919 api add middleware support (#140)
* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl support import

* goctl support import

* support return ()

* revert

* refactor and rename folder to group

* remove no need

* add anonymous annotation

* optimized

* rename

* rename

* update test

* api add middleware support: usage:

@server(
    middleware: M1, M2
)

* api add middleware support: usage:

@server(
    middleware: M1, M2
)

* simple logic

* should reverse middlewares

* optimized

* optimized

* rename

Co-authored-by: kingxt <dream4kingxt@163.com>
2020-10-19 18:34:10 +08:00
kevin
c9b0ac1ee4 add more tests 2020-10-19 15:49:11 +08:00
mywaystay
33faab61a3 add redis Zrevrank (#137)
* update goctl rpc template log print url

* add redis Zrevrank

Co-authored-by: zhangkai <zhangkai@laoyuegou.com>
2020-10-19 15:30:19 +08:00
81 changed files with 1687 additions and 488 deletions

View File

@@ -2,18 +2,18 @@ package bloom
import ( import (
"testing" "testing"
"time"
"github.com/alicebob/miniredis" "github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
) )
func TestRedisBitSet_New_Set_Test(t *testing.T) { func TestRedisBitSet_New_Set_Test(t *testing.T) {
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error("Miniredis could not start") defer clean()
}
defer s.Close()
store := redis.NewRedis(s.Addr(), redis.NodeType) store := redis.NewRedis(s.Addr(), redis.NodeType)
bitSet := newRedisBitSet(store, "test_key", 1024) bitSet := newRedisBitSet(store, "test_key", 1024)
@@ -46,11 +46,9 @@ func TestRedisBitSet_New_Set_Test(t *testing.T) {
} }
func TestRedisBitSet_Add(t *testing.T) { func TestRedisBitSet_Add(t *testing.T) {
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error("Miniredis could not start") defer clean()
}
defer s.Close()
store := redis.NewRedis(s.Addr(), redis.NodeType) store := redis.NewRedis(s.Addr(), redis.NodeType)
filter := New(store, "test_key", 64) filter := New(store, "test_key", 64)
@@ -60,3 +58,22 @@ func TestRedisBitSet_Add(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
} }
func createMiniRedis() (r *miniredis.Miniredis, clean func(), err error) {
r, err = miniredis.Run()
if err != nil {
return nil, nil, err
}
return r, func() {
ch := make(chan lang.PlaceholderType)
go func() {
r.Close()
close(ch)
}()
select {
case <-ch:
case <-time.After(time.Second):
}
}, nil
}

11
core/errorx/callchain.go Normal file
View File

@@ -0,0 +1,11 @@
package errorx
func Chain(fns ...func() error) error {
for _, fn := range fns {
if err := fn(); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,27 @@
package errorx
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestChain(t *testing.T) {
var errDummy = errors.New("dummy")
assert.Nil(t, Chain(func() error {
return nil
}, func() error {
return nil
}))
assert.Equal(t, errDummy, Chain(func() error {
return errDummy
}, func() error {
return nil
}))
assert.Equal(t, errDummy, Chain(func() error {
return nil
}, func() error {
return errDummy
}))
}

View File

@@ -73,6 +73,7 @@ type (
ZrevrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error) ZrevrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error)
ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ([]redis.Pair, error) ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ([]redis.Pair, error)
Zscore(key string, value string) (int64, error) Zscore(key string, value string) (int64, error)
Zrevrank(key, field string) (int64, error)
} }
clusterStore struct { clusterStore struct {
@@ -635,6 +636,15 @@ func (cs clusterStore) ZrevrangebyscoreWithScoresAndLimit(key string, start, sto
return node.ZrevrangebyscoreWithScoresAndLimit(key, start, stop, page, size) return node.ZrevrangebyscoreWithScoresAndLimit(key, start, stop, page, size)
} }
func (cs clusterStore) Zrevrank(key, field string) (int64, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Zrevrank(key, field)
}
func (cs clusterStore) Zscore(key string, value string) (int64, error) { func (cs clusterStore) Zscore(key string, value string) (int64, error) {
node, err := cs.getRedis(key) node, err := cs.getRedis(key)
if err != nil { if err != nil {

View File

@@ -516,6 +516,8 @@ func TestRedis_SortedSet(t *testing.T) {
assert.NotNil(t, err) assert.NotNil(t, err)
_, err = store.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1) _, err = store.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
assert.NotNil(t, err) assert.NotNil(t, err)
_, err = store.Zrevrank("key", "value")
assert.NotNil(t, err)
_, err = store.Zadds("key", redis.Pair{ _, err = store.Zadds("key", redis.Pair{
Key: "value2", Key: "value2",
Score: 6, Score: 6,
@@ -640,6 +642,9 @@ func TestRedis_SortedSet(t *testing.T) {
Score: 5, Score: 5,
}, },
}, pairs) }, pairs)
rank, err = client.Zrevrank("key", "value1")
assert.Nil(t, err)
assert.Equal(t, int64(1), rank)
val, err = client.Zadds("key", redis.Pair{ val, err = client.Zadds("key", redis.Pair{
Key: "value2", Key: "value2",
Score: 6, Score: 6,

View File

@@ -1273,6 +1273,20 @@ func (s *Redis) ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64
return return
} }
func (s *Redis) Zrevrank(key string, field string) (val int64, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
val, err = conn.ZRevRank(key, field).Result()
return err
}, acceptable)
return
}
func (s *Redis) String() string { func (s *Redis) String() string {
return s.Addr return s.Addr
} }

View File

@@ -584,6 +584,9 @@ func TestRedis_SortedSet(t *testing.T) {
rank, err := client.Zrank("key", "value2") rank, err := client.Zrank("key", "value2")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(1), rank) assert.Equal(t, int64(1), rank)
rank, err = client.Zrevrank("key", "value1")
assert.Nil(t, err)
assert.Equal(t, int64(2), rank)
_, err = NewRedis(client.Addr, "").Zrank("key", "value4") _, err = NewRedis(client.Addr, "").Zrank("key", "value4")
assert.NotNil(t, err) assert.NotNil(t, err)
_, err = client.Zrank("key", "value4") _, err = client.Zrank("key", "value4")
@@ -710,6 +713,8 @@ func TestRedis_SortedSet(t *testing.T) {
pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 0) pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 0)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 0, len(pairs)) assert.Equal(t, 0, len(pairs))
_, err = NewRedis(client.Addr, "").Zrevrank("key", "value")
assert.NotNil(t, err)
}) })
} }

View File

@@ -16,6 +16,7 @@ import (
"github.com/alicebob/miniredis" "github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stat" "github.com/tal-tech/go-zero/core/stat"
"github.com/tal-tech/go-zero/core/stores/cache" "github.com/tal-tech/go-zero/core/stores/cache"
@@ -30,10 +31,9 @@ func init() {
func TestCachedConn_GetCache(t *testing.T) { func TestCachedConn_GetCache(t *testing.T) {
resetStats() resetStats()
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
@@ -48,10 +48,9 @@ func TestCachedConn_GetCache(t *testing.T) {
func TestStat(t *testing.T) { func TestStat(t *testing.T) {
resetStats() resetStats()
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
@@ -73,10 +72,9 @@ func TestStat(t *testing.T) {
func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) { func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) {
resetStats() resetStats()
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewConn(dummySqlConn{}, cache.CacheConf{ c := NewConn(dummySqlConn{}, cache.CacheConf{
@@ -124,10 +122,9 @@ func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) {
func TestCachedConn_QueryRowIndex_HasCache(t *testing.T) { func TestCachedConn_QueryRowIndex_HasCache(t *testing.T) {
resetStats() resetStats()
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10), c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
@@ -213,11 +210,9 @@ func TestCachedConn_QueryRowIndex_HasCache_IntPrimary(t *testing.T) {
}, },
} }
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
defer s.Close()
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
@@ -261,12 +256,9 @@ func TestCachedConn_QueryRowIndex_HasWrongCache(t *testing.T) {
for k, v := range caches { for k, v := range caches {
t.Run(k+"/"+v, func(t *testing.T) { t.Run(k+"/"+v, func(t *testing.T) {
resetStats() resetStats()
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
s.FlushAll()
defer s.Close()
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10), c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
@@ -320,10 +312,9 @@ func TestStatCacheFails(t *testing.T) {
func TestStatDbFails(t *testing.T) { func TestStatDbFails(t *testing.T) {
resetStats() resetStats()
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
@@ -343,10 +334,9 @@ func TestStatDbFails(t *testing.T) {
func TestStatFromMemory(t *testing.T) { func TestStatFromMemory(t *testing.T) {
resetStats() resetStats()
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
@@ -403,10 +393,9 @@ func TestStatFromMemory(t *testing.T) {
} }
func TestCachedConnQueryRow(t *testing.T) { func TestCachedConnQueryRow(t *testing.T) {
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
const ( const (
key = "user" key = "user"
@@ -433,10 +422,9 @@ func TestCachedConnQueryRow(t *testing.T) {
} }
func TestCachedConnQueryRowFromCache(t *testing.T) { func TestCachedConnQueryRowFromCache(t *testing.T) {
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
const ( const (
key = "user" key = "user"
@@ -464,10 +452,9 @@ func TestCachedConnQueryRowFromCache(t *testing.T) {
} }
func TestQueryRowNotFound(t *testing.T) { func TestQueryRowNotFound(t *testing.T) {
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
const key = "user" const key = "user"
var conn trackedConn var conn trackedConn
@@ -486,10 +473,9 @@ func TestQueryRowNotFound(t *testing.T) {
} }
func TestCachedConnExec(t *testing.T) { func TestCachedConnExec(t *testing.T) {
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
var conn trackedConn var conn trackedConn
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
@@ -500,10 +486,9 @@ func TestCachedConnExec(t *testing.T) {
} }
func TestCachedConnExecDropCache(t *testing.T) { func TestCachedConnExecDropCache(t *testing.T) {
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
const ( const (
key = "user" key = "user"
@@ -539,10 +524,9 @@ func TestCachedConnExecDropCacheFailed(t *testing.T) {
} }
func TestCachedConnQueryRows(t *testing.T) { func TestCachedConnQueryRows(t *testing.T) {
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
var conn trackedConn var conn trackedConn
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
@@ -554,10 +538,9 @@ func TestCachedConnQueryRows(t *testing.T) {
} }
func TestCachedConnTransact(t *testing.T) { func TestCachedConnTransact(t *testing.T) {
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
var conn trackedConn var conn trackedConn
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
@@ -570,10 +553,9 @@ func TestCachedConnTransact(t *testing.T) {
} }
func TestQueryRowNoCache(t *testing.T) { func TestQueryRowNoCache(t *testing.T) {
s, err := miniredis.Run() s, clean, err := createMiniRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
const ( const (
key = "user" key = "user"
@@ -657,3 +639,22 @@ func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error {
c.transactValue = true c.transactValue = true
return c.dummySqlConn.Transact(fn) return c.dummySqlConn.Transact(fn)
} }
func createMiniRedis() (r *miniredis.Miniredis, clean func(), err error) {
r, err = miniredis.Run()
if err != nil {
return nil, nil, err
}
return r, func() {
ch := make(chan lang.PlaceholderType)
go func() {
r.Close()
close(ch)
}()
select {
case <-ch:
case <-time.After(time.Second):
}
}, nil
}

View File

@@ -56,7 +56,7 @@ func main() {
Port: *port, Port: *port,
Timeout: *timeout, Timeout: *timeout,
MaxConns: 500, MaxConns: 500,
}) }, rest.WithNotAllowedHandler(rest.CorsHandler()))
defer engine.Stop() defer engine.Stop()
engine.Use(first) engine.Use(first)

View File

@@ -173,6 +173,6 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
如果您发现bug请及时提issue我们会尽快确认并修改。 如果您发现bug请及时提issue我们会尽快确认并修改。
扫码后请加群主,便于我邀请您进讨论群,并请退出扫码网关群,谢谢! <!-- 扫码后请加群主,便于我邀请您进讨论群,并请退出扫码网关群,谢谢!-->
<img src="https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/wechat.jpg" alt="wechat" width="300" /> <img src="https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/wechat.jpg" alt="wechat" width="300" />

View File

@@ -57,7 +57,7 @@ func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
} }
func (s *engine) Start() error { func (s *engine) Start() error {
return s.StartWithRouter(router.NewPatRouter()) return s.StartWithRouter(router.NewRouter())
} }
func (s *engine) StartWithRouter(router httpx.Router) error { func (s *engine) StartWithRouter(router httpx.Router) error {

27
rest/handlers.go Normal file
View File

@@ -0,0 +1,27 @@
package rest
import "net/http"
const (
allowOrigin = "Access-Control-Allow-Origin"
allOrigins = "*"
allowMethods = "Access-Control-Allow-Methods"
allowHeaders = "Access-Control-Allow-Headers"
headers = "Content-Type, Content-Length, Origin"
methods = "GET, HEAD, POST, PATCH, PUT, DELETE"
)
// CorsHandler handles cross domain OPTIONS requests.
// At most one origin can be specified, other origins are ignored if given.
func CorsHandler(origins ...string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if len(origins) > 0 {
w.Header().Set(allowOrigin, origins[0])
} else {
w.Header().Set(allowOrigin, allOrigins)
}
w.Header().Set(allowMethods, methods)
w.Header().Set(allowHeaders, headers)
w.WriteHeader(http.StatusNoContent)
})
}

42
rest/handlers_test.go Normal file
View File

@@ -0,0 +1,42 @@
package rest
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCorsHandlerWithOrigins(t *testing.T) {
tests := []struct {
name string
origins []string
expect string
}{
{
name: "allow all origins",
expect: allOrigins,
},
{
name: "allow one origin",
origins: []string{"local"},
expect: "local",
},
{
name: "allow many origins",
origins: []string{"local", "remote"},
expect: "local",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
w := httptest.NewRecorder()
handler := CorsHandler(test.origins...)
handler.ServeHTTP(w, nil)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
})
}
}

View File

@@ -6,4 +6,5 @@ type Router interface {
http.Handler http.Handler
Handle(method string, path string, handler http.Handler) error Handle(method string, path string, handler http.Handler) error
SetNotFoundHandler(handler http.Handler) SetNotFoundHandler(handler http.Handler)
SetNotAllowedHandler(handler http.Handler)
} }

View File

@@ -21,18 +21,19 @@ var (
ErrInvalidPath = errors.New("path must begin with '/'") ErrInvalidPath = errors.New("path must begin with '/'")
) )
type PatRouter struct { type patRouter struct {
trees map[string]*search.Tree trees map[string]*search.Tree
notFound http.Handler notFound http.Handler
notAllowed http.Handler
} }
func NewPatRouter() httpx.Router { func NewRouter() httpx.Router {
return &PatRouter{ return &patRouter{
trees: make(map[string]*search.Tree), trees: make(map[string]*search.Tree),
} }
} }
func (pr *PatRouter) Handle(method, reqPath string, handler http.Handler) error { func (pr *patRouter) Handle(method, reqPath string, handler http.Handler) error {
if !validMethod(method) { if !validMethod(method) {
return ErrInvalidMethod return ErrInvalidMethod
} }
@@ -51,7 +52,7 @@ func (pr *PatRouter) Handle(method, reqPath string, handler http.Handler) error
} }
} }
func (pr *PatRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
reqPath := path.Clean(r.URL.Path) reqPath := path.Clean(r.URL.Path)
if tree, ok := pr.trees[r.Method]; ok { if tree, ok := pr.trees[r.Method]; ok {
if result, ok := tree.Search(reqPath); ok { if result, ok := tree.Search(reqPath); ok {
@@ -63,19 +64,29 @@ func (pr *PatRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok { allow, ok := pr.methodNotAllowed(r.Method, reqPath)
if !ok {
pr.handleNotFound(w, r)
return
}
if pr.notAllowed != nil {
pr.notAllowed.ServeHTTP(w, r)
} else {
w.Header().Set(allowHeader, allow) w.Header().Set(allowHeader, allow)
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
} else {
pr.handleNotFound(w, r)
} }
} }
func (pr *PatRouter) SetNotFoundHandler(handler http.Handler) { func (pr *patRouter) SetNotFoundHandler(handler http.Handler) {
pr.notFound = handler pr.notFound = handler
} }
func (pr *PatRouter) handleNotFound(w http.ResponseWriter, r *http.Request) { func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) {
pr.notAllowed = handler
}
func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
if pr.notFound != nil { if pr.notFound != nil {
pr.notFound.ServeHTTP(w, r) pr.notFound.ServeHTTP(w, r)
} else { } else {
@@ -83,7 +94,7 @@ func (pr *PatRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
} }
} }
func (pr *PatRouter) methodNotAllowed(method, path string) (string, bool) { func (pr *patRouter) methodNotAllowed(method, path string) (string, bool) {
var allows []string var allows []string
for treeMethod, tree := range pr.trees { for treeMethod, tree := range pr.trees {

View File

@@ -47,7 +47,7 @@ func TestPatRouterHandleErrors(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.method, func(t *testing.T) { t.Run(test.method, func(t *testing.T) {
router := NewPatRouter() router := NewRouter()
err := router.Handle(test.method, test.path, nil) err := router.Handle(test.method, test.path, nil)
assert.Error(t, ErrInvalidMethod, err) assert.Error(t, ErrInvalidMethod, err)
}) })
@@ -56,17 +56,34 @@ func TestPatRouterHandleErrors(t *testing.T) {
func TestPatRouterNotFound(t *testing.T) { func TestPatRouterNotFound(t *testing.T) {
var notFound bool var notFound bool
router := NewPatRouter() router := NewRouter()
router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
notFound = true notFound = true
})) }))
router.Handle(http.MethodGet, "/a/b", nil) err := router.Handle(http.MethodGet, "/a/b",
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
assert.Nil(t, err)
r, _ := http.NewRequest(http.MethodGet, "/b/c", nil) r, _ := http.NewRequest(http.MethodGet, "/b/c", nil)
w := new(mockedResponseWriter) w := new(mockedResponseWriter)
router.ServeHTTP(w, r) router.ServeHTTP(w, r)
assert.True(t, notFound) assert.True(t, notFound)
} }
func TestPatRouterNotAllowed(t *testing.T) {
var notAllowed bool
router := NewRouter()
router.SetNotAllowedHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
notAllowed = true
}))
err := router.Handle(http.MethodGet, "/a/b",
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
assert.Nil(t, err)
r, _ := http.NewRequest(http.MethodPost, "/a/b", nil)
w := new(mockedResponseWriter)
router.ServeHTTP(w, r)
assert.True(t, notAllowed)
}
func TestPatRouter(t *testing.T) { func TestPatRouter(t *testing.T) {
tests := []struct { tests := []struct {
method string method string
@@ -87,7 +104,7 @@ func TestPatRouter(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.method+":"+test.path, func(t *testing.T) { t.Run(test.method+":"+test.path, func(t *testing.T) {
routed := false routed := false
router := NewPatRouter() router := NewRouter()
err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := router.Handle(test.method, "/a/:b", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routed = true routed = true
assert.Equal(t, 1, len(context.Vars(r))) assert.Equal(t, 1, len(context.Vars(r)))
@@ -125,7 +142,7 @@ func TestParseSlice(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rt := NewPatRouter() rt := NewRouter()
err = rt.Handle(http.MethodPost, "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err = rt.Handle(http.MethodPost, "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
Names []string `form:"names"` Names []string `form:"names"`
@@ -149,7 +166,7 @@ func TestParseJsonPost(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(httpx.ContentType, httpx.ApplicationJson) r.Header.Set(httpx.ContentType, httpx.ApplicationJson)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(
w http.ResponseWriter, r *http.Request) { w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -181,7 +198,7 @@ func TestParseJsonPostWithIntSlice(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(httpx.ContentType, httpx.ApplicationJson) r.Header.Set(httpx.ContentType, httpx.ApplicationJson)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(
w http.ResponseWriter, r *http.Request) { w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -209,7 +226,7 @@ func TestParseJsonPostError(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(httpx.ContentType, httpx.ApplicationJson) r.Header.Set(httpx.ContentType, httpx.ApplicationJson)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -237,7 +254,7 @@ func TestParseJsonPostInvalidRequest(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(httpx.ContentType, httpx.ApplicationJson) r.Header.Set(httpx.ContentType, httpx.ApplicationJson)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/", http.HandlerFunc( err = router.Handle(http.MethodPost, "/", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -259,7 +276,7 @@ func TestParseJsonPostRequired(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(httpx.ContentType, httpx.ApplicationJson) r.Header.Set(httpx.ContentType, httpx.ApplicationJson)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -282,7 +299,7 @@ func TestParsePath(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -307,7 +324,7 @@ func TestParsePathRequired(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodGet, "/:name/", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -328,7 +345,7 @@ func TestParseQuery(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -353,7 +370,7 @@ func TestParseQueryRequired(t *testing.T) {
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever", nil) r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
Nickname string `form:"nickname"` Nickname string `form:"nickname"`
@@ -373,7 +390,7 @@ func TestParseOptional(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -414,7 +431,7 @@ func TestParseNestedInRequestEmpty(t *testing.T) {
} }
) )
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@@ -453,7 +470,7 @@ func TestParsePtrInRequest(t *testing.T) {
} }
) )
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@@ -484,7 +501,7 @@ func TestParsePtrInRequestEmpty(t *testing.T) {
} }
) )
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/kevin", http.HandlerFunc( err = router.Handle(http.MethodPost, "/kevin", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@@ -501,7 +518,7 @@ func TestParseQueryOptional(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -526,7 +543,7 @@ func TestParse(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -564,7 +581,7 @@ func TestParseWrappedRequest(t *testing.T) {
} }
) )
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@@ -596,7 +613,7 @@ func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
} }
) )
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@@ -629,7 +646,7 @@ func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) {
} }
) )
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodHead, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodHead, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@@ -661,7 +678,7 @@ func TestParseWrappedRequestPtr(t *testing.T) {
} }
) )
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
var v WrappedRequest var v WrappedRequest
@@ -684,7 +701,7 @@ func TestParseWithAll(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(httpx.ContentType, httpx.ApplicationJson) r.Header.Set(httpx.ContentType, httpx.ApplicationJson)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
Name string `path:"name"` Name string `path:"name"`
@@ -715,7 +732,7 @@ func TestParseWithAllUtf8(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(httpx.ContentType, applicationJsonWithUtf8) r.Header.Set(httpx.ContentType, applicationJsonWithUtf8)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -746,7 +763,7 @@ func TestParseWithMissingForm(t *testing.T) {
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -773,7 +790,7 @@ func TestParseWithMissingAllForms(t *testing.T) {
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -799,7 +816,7 @@ func TestParseWithMissingJson(t *testing.T) {
bytes.NewBufferString(`{"location": "shanghai"}`)) bytes.NewBufferString(`{"location": "shanghai"}`))
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -825,7 +842,7 @@ func TestParseWithMissingAllJsons(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -852,7 +869,7 @@ func TestParseWithMissingPath(t *testing.T) {
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -879,7 +896,7 @@ func TestParseWithMissingAllPaths(t *testing.T) {
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
assert.Nil(t, err) assert.Nil(t, err)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -906,7 +923,7 @@ func TestParseGetWithContentLengthHeader(t *testing.T) {
r.Header.Set(httpx.ContentType, httpx.ApplicationJson) r.Header.Set(httpx.ContentType, httpx.ApplicationJson)
r.Header.Set(contentLength, "1024") r.Header.Set(contentLength, "1024")
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -933,7 +950,7 @@ func TestParseJsonPostWithTypeMismatch(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(httpx.ContentType, applicationJsonWithUtf8) r.Header.Set(httpx.ContentType, applicationJsonWithUtf8)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -959,7 +976,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
r.Header.Set(httpx.ContentType, applicationJsonWithUtf8) r.Header.Set(httpx.ContentType, applicationJsonWithUtf8)
router := NewPatRouter() router := NewRouter()
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
v := struct { v := struct {
@@ -980,7 +997,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) {
func BenchmarkPatRouter(b *testing.B) { func BenchmarkPatRouter(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
router := NewPatRouter() router := NewRouter()
router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
})) }))
w := &mockedResponseWriter{} w := &mockedResponseWriter{}

View File

@@ -1,12 +1,14 @@
package rest package rest
import ( import (
"errors"
"log" "log"
"net/http" "net/http"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/rest/handler" "github.com/tal-tech/go-zero/rest/handler"
"github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/httpx"
"github.com/tal-tech/go-zero/rest/router"
) )
type ( type (
@@ -32,6 +34,10 @@ func MustNewServer(c RestConf, opts ...RunOption) *Server {
} }
func NewServer(c RestConf, opts ...RunOption) (*Server, error) { func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
if len(opts) > 1 {
return nil, errors.New("only one RunOption is allowed")
}
if err := c.SetUp(); err != nil { if err := c.SetUp(); err != nil {
return nil, err return nil, err
} }
@@ -103,6 +109,13 @@ func WithJwtTransition(secret, prevSecret string) RouteOption {
} }
} }
func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
for i := len(ms) - 1; i >= 0; i-- {
rs = WithMiddleware(ms[i], rs...)
}
return rs
}
func WithMiddleware(middleware Middleware, rs ...Route) []Route { func WithMiddleware(middleware Middleware, rs ...Route) []Route {
routes := make([]Route, len(rs)) routes := make([]Route, len(rs))
@@ -118,6 +131,18 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route {
return routes return routes
} }
func WithNotFoundHandler(handler http.Handler) RunOption {
rt := router.NewRouter()
rt.SetNotFoundHandler(handler)
return WithRouter(rt)
}
func WithNotAllowedHandler(handler http.Handler) RunOption {
rt := router.NewRouter()
rt.SetNotAllowedHandler(handler)
return WithRouter(rt)
}
func WithPriority() RouteOption { func WithPriority() RouteOption {
return func(r *featuredRoutes) { return func(r *featuredRoutes) {
r.priority = true r.priority = true

View File

@@ -12,9 +12,14 @@ import (
"github.com/tal-tech/go-zero/rest/router" "github.com/tal-tech/go-zero/rest/router"
) )
func TestNewServer(t *testing.T) {
_, err := NewServer(RestConf{}, WithNotFoundHandler(nil), WithNotAllowedHandler(nil))
assert.NotNil(t, err)
}
func TestWithMiddleware(t *testing.T) { func TestWithMiddleware(t *testing.T) {
m := make(map[string]string) m := make(map[string]string)
router := router.NewPatRouter() router := router.NewRouter()
handler := func(w http.ResponseWriter, r *http.Request) { handler := func(w http.ResponseWriter, r *http.Request) {
var v struct { var v struct {
Nickname string `form:"nickname"` Nickname string `form:"nickname"`
@@ -68,3 +73,81 @@ func TestWithMiddleware(t *testing.T) {
"wan": "2020", "wan": "2020",
}, m) }, m)
} }
func TestMultiMiddlewares(t *testing.T) {
m := make(map[string]string)
router := router.NewRouter()
handler := func(w http.ResponseWriter, r *http.Request) {
var v struct {
Nickname string `form:"nickname"`
Zipcode int64 `form:"zipcode"`
}
err := httpx.Parse(r, &v)
assert.Nil(t, err)
_, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
assert.Nil(t, err)
}
rs := WithMiddlewares([]Middleware{
func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var v struct {
Name string `path:"name"`
Year string `path:"year"`
}
assert.Nil(t, httpx.ParsePath(r, &v))
m[v.Name] = v.Year
next.ServeHTTP(w, r)
}
},
func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var v struct {
Name string `form:"nickname"`
Zipcode string `form:"zipcode"`
}
assert.Nil(t, httpx.ParseForm(r, &v))
assert.NotEmpty(t, m)
m[v.Name] = v.Zipcode + v.Zipcode
next.ServeHTTP(w, r)
}
},
}, Route{
Method: http.MethodGet,
Path: "/first/:name/:year",
Handler: handler,
}, Route{
Method: http.MethodGet,
Path: "/second/:name/:year",
Handler: handler,
})
urls := []string{
"http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
}
for _, route := range rs {
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler))
}
for _, url := range urls {
r, err := http.NewRequest(http.MethodGet, url, nil)
assert.Nil(t, err)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, r)
assert.Equal(t, "whatever:200000200000", rr.Body.String())
}
assert.EqualValues(t, map[string]string{
"kevin": "2017",
"wan": "2020",
"whatever": "200000200000",
}, m)
}
func TestWithPriority(t *testing.T) {
var fr featuredRoutes
WithPriority()(&fr)
assert.True(t, fr.priority)
}

View File

@@ -1,11 +1,15 @@
package parser package gogen
import ( import (
goformat "go/format"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/api/parser"
) )
const testApiTemplate = ` const testApiTemplate = `
@@ -119,13 +123,68 @@ service A-api {
} }
` `
const apiHasMiddleware = `
type Request struct {
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
}
type Response struct {
Message string ` + "`" + `json:"message"` + "`" + `
}
@server(
middleware: TokenValidate
)
service A-api {
@handler GreetHandler
get /greet/from/:name(Request) returns (Response)
}
`
const apiJwt = `
type Request struct {
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
}
type Response struct {
Message string ` + "`" + `json:"message"` + "`" + `
}
@server(
jwt: Auth
)
service A-api {
@handler GreetHandler
get /greet/from/:name(Request) returns (Response)
}
`
const apiJwtWithMiddleware = `
type Request struct {
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
}
type Response struct {
Message string ` + "`" + `json:"message"` + "`" + `
}
@server(
jwt: Auth
middleware: TokenValidate
)
service A-api {
@handler GreetHandler
get /greet/from/:name(Request) returns (Response)
}
`
func TestParser(t *testing.T) { func TestParser(t *testing.T) {
filename := "greet.api" filename := "greet.api"
err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm) err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm)
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := NewParser(filename) parser, err := parser.NewParser(filename)
assert.Nil(t, err) assert.Nil(t, err)
api, err := parser.Parse() api, err := parser.Parse()
@@ -139,6 +198,8 @@ func TestParser(t *testing.T) {
assert.Equal(t, api.Service.Routes[1].RequestType.Name, "Request") assert.Equal(t, api.Service.Routes[1].RequestType.Name, "Request")
assert.Equal(t, api.Service.Routes[1].ResponseType.Name, "") assert.Equal(t, api.Service.Routes[1].ResponseType.Name, "")
validate(t, filename)
} }
func TestMultiService(t *testing.T) { func TestMultiService(t *testing.T) {
@@ -147,7 +208,7 @@ func TestMultiService(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := NewParser(filename) parser, err := parser.NewParser(filename)
assert.Nil(t, err) assert.Nil(t, err)
api, err := parser.Parse() api, err := parser.Parse()
@@ -155,6 +216,8 @@ func TestMultiService(t *testing.T) {
assert.Equal(t, len(api.Service.Routes), 2) assert.Equal(t, len(api.Service.Routes), 2)
assert.Equal(t, len(api.Service.Groups), 2) assert.Equal(t, len(api.Service.Groups), 2)
validate(t, filename)
} }
func TestApiNoInfo(t *testing.T) { func TestApiNoInfo(t *testing.T) {
@@ -163,11 +226,13 @@ func TestApiNoInfo(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := NewParser(filename) parser, err := parser.NewParser(filename)
assert.Nil(t, err) assert.Nil(t, err)
_, err = parser.Parse() _, err = parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
validate(t, filename)
} }
func TestInvalidApiFile(t *testing.T) { func TestInvalidApiFile(t *testing.T) {
@@ -176,7 +241,7 @@ func TestInvalidApiFile(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := NewParser(filename) parser, err := parser.NewParser(filename)
assert.Nil(t, err) assert.Nil(t, err)
_, err = parser.Parse() _, err = parser.Parse()
@@ -189,7 +254,7 @@ func TestAnonymousAnnotation(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
parser, err := NewParser(filename) parser, err := parser.NewParser(filename)
assert.Nil(t, err) assert.Nil(t, err)
api, err := parser.Parse() api, err := parser.Parse()
@@ -197,4 +262,71 @@ func TestAnonymousAnnotation(t *testing.T) {
assert.Equal(t, len(api.Service.Routes), 1) assert.Equal(t, len(api.Service.Routes), 1)
assert.Equal(t, api.Service.Routes[0].Annotations[0].Value, "GreetHandler") assert.Equal(t, api.Service.Routes[0].Annotations[0].Value, "GreetHandler")
validate(t, filename)
}
func TestApiHasMiddleware(t *testing.T) {
filename := "greet.api"
err := ioutil.WriteFile(filename, []byte(apiHasMiddleware), os.ModePerm)
assert.Nil(t, err)
defer os.Remove(filename)
parser, err := parser.NewParser(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err)
validate(t, filename)
}
func TestApiHasJwt(t *testing.T) {
filename := "jwt.api"
err := ioutil.WriteFile(filename, []byte(apiJwt), os.ModePerm)
assert.Nil(t, err)
defer os.Remove(filename)
parser, err := parser.NewParser(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err)
validate(t, filename)
}
func TestApiHasJwtAndMiddleware(t *testing.T) {
filename := "jwt.api"
err := ioutil.WriteFile(filename, []byte(apiJwtWithMiddleware), os.ModePerm)
assert.Nil(t, err)
defer os.Remove(filename)
parser, err := parser.NewParser(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err)
validate(t, filename)
}
func validate(t *testing.T, api string) {
dir := "_go"
err := DoGenProject(api, dir, true)
defer os.RemoveAll(dir)
assert.Nil(t, err)
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if strings.HasSuffix(path, ".go") {
code, err := ioutil.ReadFile(path)
assert.Nil(t, err)
assert.Nil(t, validateCode(string(code)))
}
return nil
})
}
func validateCode(code string) error {
_, err := goformat.Source([]byte(code))
return err
} }

View File

@@ -8,7 +8,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/templatex" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars" "github.com/tal-tech/go-zero/tools/goctl/vars"
) )
@@ -48,7 +48,7 @@ func genConfig(dir string, api *spec.ApiSpec) error {
} }
var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl) var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceUrl)
text, err := templatex.LoadTemplate(category, configTemplateFile, configTemplate) text, err := ctlutil.LoadTemplate(category, configTemplateFile, configTemplate)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -8,7 +8,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/templatex" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
) )
const ( const (
@@ -40,7 +40,7 @@ func genEtc(dir string, api *spec.ApiSpec) error {
port = strconv.Itoa(defaultPort) port = strconv.Itoa(defaultPort)
} }
text, err := templatex.LoadTemplate(category, etcTemplateFile, etcTemplate) text, err := ctlutil.LoadTemplate(category, etcTemplateFile, etcTemplate)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -9,7 +9,6 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util" apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/templatex"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars" "github.com/tal-tech/go-zero/tools/goctl/vars"
) )
@@ -94,7 +93,7 @@ func doGenToFile(dir, handler string, group spec.Group, route spec.Route, handle
} }
defer fp.Close() defer fp.Close()
text, err := templatex.LoadTemplate(category, handlerTemplateFile, handlerTemplate) text, err := util.LoadTemplate(category, handlerTemplateFile, handlerTemplate)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -9,7 +9,6 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/templatex"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars" "github.com/tal-tech/go-zero/tools/goctl/vars"
) )
@@ -94,7 +93,7 @@ func genLogicByRoute(dir string, group spec.Group, route spec.Route) error {
requestString = "req " + "types." + strings.Title(route.RequestType.Name) requestString = "req " + "types." + strings.Title(route.RequestType.Name)
} }
text, err := templatex.LoadTemplate(category, logicTemplateFile, logicTemplate) text, err := ctlutil.LoadTemplate(category, logicTemplateFile, logicTemplate)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -8,7 +8,6 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/templatex"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars" "github.com/tal-tech/go-zero/tools/goctl/vars"
) )
@@ -61,7 +60,7 @@ func genMain(dir string, api *spec.ApiSpec) error {
return err return err
} }
text, err := templatex.LoadTemplate(category, mainTemplateFile, mainTemplate) text, err := ctlutil.LoadTemplate(category, mainTemplateFile, mainTemplate)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -31,9 +31,9 @@ func RegisterHandlers(engine *rest.Server, serverCtx *svc.ServiceContext) {
} }
` `
routesAdditionTemplate = ` routesAdditionTemplate = `
engine.AddRoutes([]rest.Route{ engine.AddRoutes(
{{.routes}} {{.routes}} {{.jwt}}{{.signature}}
}{{.jwt}}{{.signature}}) )
` `
) )
@@ -52,6 +52,7 @@ type (
jwtEnabled bool jwtEnabled bool
signatureEnabled bool signatureEnabled bool
authName string authName string
middleware []string
} }
route struct { route struct {
method string method string
@@ -70,6 +71,7 @@ func genRoutes(dir string, api *spec.ApiSpec, force bool) error {
gt := template.Must(template.New("groupTemplate").Parse(routesAdditionTemplate)) gt := template.Must(template.New("groupTemplate").Parse(routesAdditionTemplate))
for _, g := range groups { for _, g := range groups {
var gbuilder strings.Builder var gbuilder strings.Builder
gbuilder.WriteString("[]rest.Route{")
for _, r := range g.routes { for _, r := range g.routes {
fmt.Fprintf(&gbuilder, ` fmt.Fprintf(&gbuilder, `
{ {
@@ -79,16 +81,33 @@ func genRoutes(dir string, api *spec.ApiSpec, force bool) error {
},`, },`,
r.method, r.path, r.handler) r.method, r.path, r.handler)
} }
var jwt string var jwt string
if g.jwtEnabled { if g.jwtEnabled {
jwt = fmt.Sprintf(", rest.WithJwt(serverCtx.Config.%s.AccessSecret)", g.authName) jwt = fmt.Sprintf("\n rest.WithJwt(serverCtx.Config.%s.AccessSecret),", g.authName)
} }
var signature string var signature string
if g.signatureEnabled { if g.signatureEnabled {
signature = fmt.Sprintf(", rest.WithSignature(serverCtx.Config.%s.Signature)", g.authName) signature = fmt.Sprintf("\n rest.WithSignature(serverCtx.Config.%s.Signature),", g.authName)
} }
var routes string
if len(g.middleware) > 0 {
gbuilder.WriteString("\n}...,")
var params = g.middleware
for i := range params {
params[i] = "serverCtx." + params[i]
}
var middlewareStr = strings.Join(params, ", ")
routes = fmt.Sprintf("rest.WithMiddlewares(\n[]rest.Middleware{ %s }, \n %s \n),",
middlewareStr, strings.TrimSpace(gbuilder.String()))
} else {
gbuilder.WriteString("\n},")
routes = strings.TrimSpace(gbuilder.String())
}
if err := gt.Execute(&builder, map[string]string{ if err := gt.Execute(&builder, map[string]string{
"routes": strings.TrimSpace(gbuilder.String()), "routes": routes,
"jwt": jwt, "jwt": jwt,
"signature": signature, "signature": signature,
}); err != nil { }); err != nil {
@@ -185,6 +204,11 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
groupedRoutes.authName = value groupedRoutes.authName = value
groupedRoutes.jwtEnabled = true groupedRoutes.jwtEnabled = true
} }
if value, ok := apiutil.GetAnnotationValue(g.Annotations, "server", "middleware"); ok {
for _, item := range strings.Split(value, ",") {
groupedRoutes.middleware = append(groupedRoutes.middleware, item)
}
}
routes = append(routes, groupedRoutes) routes = append(routes, groupedRoutes)
} }

View File

@@ -7,18 +7,21 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/templatex"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars"
) )
const ( const (
contextFilename = "servicecontext.go" contextFilename = "servicecontext.go"
contextTemplate = `package svc contextTemplate = `package svc
import {{.configImport}} import (
{{.configImport}}
)
type ServiceContext struct { type ServiceContext struct {
Config {{.config}} Config {{.config}}
{{.middleware}}
} }
func NewServiceContext(c {{.config}}) *ServiceContext { func NewServiceContext(c {{.config}}) *ServiceContext {
@@ -48,17 +51,27 @@ func genServiceContext(dir string, api *spec.ApiSpec) error {
return err return err
} }
text, err := templatex.LoadTemplate(category, contextTemplateFile, contextTemplate) text, err := ctlutil.LoadTemplate(category, contextTemplateFile, contextTemplate)
if err != nil { if err != nil {
return err return err
} }
var middlewareStr string
for _, item := range getMiddleware(api) {
middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
}
var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\"" var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
if len(middlewareStr) > 0 {
configImport += fmt.Sprintf("\n\"%s/rest\"", vars.ProjectOpenSourceUrl)
}
t := template.Must(template.New("contextTemplate").Parse(text)) t := template.Must(template.New("contextTemplate").Parse(text))
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
err = t.Execute(buffer, map[string]string{ err = t.Execute(buffer, map[string]string{
"configImport": configImport, "configImport": configImport,
"config": "config.Config", "config": "config.Config",
"middleware": middlewareStr,
}) })
if err != nil { if err != nil {
return nil return nil

View File

@@ -1,7 +1,9 @@
package gogen package gogen
import ( import (
"github.com/tal-tech/go-zero/tools/goctl/templatex" "fmt"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
@@ -25,5 +27,29 @@ var templates = map[string]string{
} }
func GenTemplates(_ *cli.Context) error { func GenTemplates(_ *cli.Context) error {
return templatex.InitTemplates(category, templates) return util.InitTemplates(category, templates)
}
func RevertTemplate(name string) error {
content, ok := templates[name]
if !ok {
return fmt.Errorf("%s: no such file name", name)
}
return util.CreateTemplate(category, name, content)
}
func Update(category string) error {
err := Clean()
if err != nil {
return err
}
return util.InitTemplates(category, templates)
}
func Clean() error {
return util.Clean(category)
}
func GetCategory() string {
return category
} }

View File

@@ -0,0 +1,92 @@
package gogen
import (
"io/ioutil"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
func TestGenTemplates(t *testing.T) {
err := util.InitTemplates(category, templates)
assert.Nil(t, err)
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, "main.tpl")
data, err := ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, string(data), mainTemplate)
}
func TestRevertTemplate(t *testing.T) {
name := "main.tpl"
err := util.InitTemplates(category, templates)
assert.Nil(t, err)
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, name)
data, err := ioutil.ReadFile(file)
assert.Nil(t, err)
modifyData := string(data) + "modify"
err = util.CreateTemplate(category, name, modifyData)
assert.Nil(t, err)
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, string(data), modifyData)
assert.Nil(t, RevertTemplate(name))
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, mainTemplate, string(data))
}
func TestClean(t *testing.T) {
name := "main.tpl"
err := util.InitTemplates(category, templates)
assert.Nil(t, err)
assert.Nil(t, Clean())
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, name)
_, err = ioutil.ReadFile(file)
assert.NotNil(t, err)
}
func TestUpdate(t *testing.T) {
name := "main.tpl"
err := util.InitTemplates(category, templates)
assert.Nil(t, err)
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, name)
data, err := ioutil.ReadFile(file)
assert.Nil(t, err)
modifyData := string(data) + "modify"
err = util.CreateTemplate(category, name, modifyData)
assert.Nil(t, err)
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, string(data), modifyData)
assert.Nil(t, Update(category))
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, mainTemplate, string(data))
}

View File

@@ -66,6 +66,18 @@ func getAuths(api *spec.ApiSpec) []string {
return authNames.KeysStr() return authNames.KeysStr()
} }
func getMiddleware(api *spec.ApiSpec) []string {
result := collection.NewSet()
for _, g := range api.Service.Groups {
if value, ok := util.GetAnnotationValue(g.Annotations, "server", "middleware"); ok {
for _, item := range strings.Split(value, ",") {
result.Add(strings.TrimSpace(item))
}
}
}
return result.KeysStr()
}
func formatCode(code string) string { func formatCode(code string) string {
ret, err := goformat.Source([]byte(code)) ret, err := goformat.Source([]byte(code))
if err != nil { if err != nil {

View File

@@ -1,18 +0,0 @@
package feature
import (
"fmt"
"github.com/logrusorgru/aurora"
"github.com/urfave/cli"
)
var feature = `
1、增加goctl model支持
`
func Feature(_ *cli.Context) error {
fmt.Println(aurora.Blue("\nFEATURE:"))
fmt.Println(aurora.Blue(feature))
return nil
}

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"fmt" "fmt"
"os" "os"
"runtime"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/tools/goctl/api/apigen" "github.com/tal-tech/go-zero/tools/goctl/api/apigen"
@@ -17,15 +18,15 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/validate" "github.com/tal-tech/go-zero/tools/goctl/api/validate"
"github.com/tal-tech/go-zero/tools/goctl/configgen" "github.com/tal-tech/go-zero/tools/goctl/configgen"
"github.com/tal-tech/go-zero/tools/goctl/docker" "github.com/tal-tech/go-zero/tools/goctl/docker"
"github.com/tal-tech/go-zero/tools/goctl/feature"
model "github.com/tal-tech/go-zero/tools/goctl/model/sql/command" model "github.com/tal-tech/go-zero/tools/goctl/model/sql/command"
rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/command" rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/command"
"github.com/tal-tech/go-zero/tools/goctl/tpl"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
var ( var (
BuildTime = "not set" BuildVersion = "20201021"
commands = []cli.Command{ commands = []cli.Command{
{ {
Name: "api", Name: "api",
Usage: "generate api related files", Usage: "generate api related files",
@@ -102,13 +103,6 @@ var (
}, },
}, },
Action: gogen.GoCommand, Action: gogen.GoCommand,
Subcommands: []cli.Command{
{
Name: "template",
Usage: "initialize the api templates",
Action: gogen.GenTemplates,
},
},
}, },
{ {
Name: "java", Name: "java",
@@ -335,9 +329,46 @@ var (
Action: configgen.GenConfigCommand, Action: configgen.GenConfigCommand,
}, },
{ {
Name: "feature", Name: "template",
Usage: "the features of the latest version", Usage: "template operation",
Action: feature.Feature, Subcommands: []cli.Command{
{
Name: "init",
Usage: "initialize the all templates(force update)",
Action: tpl.GenTemplates,
},
{
Name: "clean",
Usage: "clean the all cache templates",
Action: tpl.CleanTemplates,
},
{
Name: "update",
Usage: "update template of the target category to the latest",
Flags: []cli.Flag{
cli.StringFlag{
Name: "category,c",
Usage: "the category of template, enum [api,rpc,model]",
},
},
Action: tpl.UpdateTemplates,
},
{
Name: "revert",
Usage: "revert the target template to the latest",
Flags: []cli.Flag{
cli.StringFlag{
Name: "category,c",
Usage: "the category of template, enum [api,rpc,model]",
},
cli.StringFlag{
Name: "name,n",
Usage: "the target file name of template",
},
},
Action: tpl.RevertTemplates,
},
},
}, },
} }
) )
@@ -347,7 +378,7 @@ func main() {
app := cli.NewApp() app := cli.NewApp()
app.Usage = "a cli tool to generate code" app.Usage = "a cli tool to generate code"
app.Version = BuildTime app.Version = fmt.Sprintf("%s %s/%s", BuildVersion, runtime.GOOS, runtime.GOARCH)
app.Commands = commands app.Commands = commands
// cli already print error messages // cli already print error messages
if err := app.Run(os.Args); err != nil { if err := app.Run(os.Args); err != nil {

View File

@@ -67,6 +67,7 @@ spec:
path: /usr/share/zoneinfo/Asia/Shanghai path: /usr/share/zoneinfo/Asia/Shanghai
--- ---
apiVersion: v1 apiVersion: v1
kind: Service kind: Service
metadata: metadata:
@@ -84,6 +85,7 @@ spec:
type: NodePort{{if .envIsPreOrPro}} type: NodePort{{if .envIsPreOrPro}}
--- ---
apiVersion: autoscaling/v2beta1 apiVersion: autoscaling/v2beta1
kind: HorizontalPodAutoscaler kind: HorizontalPodAutoscaler
metadata: metadata:
@@ -105,6 +107,7 @@ spec:
targetAverageUtilization: 80 targetAverageUtilization: 80
--- ---
apiVersion: autoscaling/v2beta1 apiVersion: autoscaling/v2beta1
kind: HorizontalPodAutoscaler kind: HorizontalPodAutoscaler
metadata: metadata:
@@ -123,4 +126,5 @@ spec:
- type: Resource - type: Resource
resource: resource:
name: memory name: memory
targetAverageUtilization: 80{{end}}` targetAverageUtilization: 80{{end}}
`

View File

@@ -1,6 +1,5 @@
package k8s package k8s
// 无环境区分
var jobTmeplate = `apiVersion: batch/v1beta1 var jobTmeplate = `apiVersion: batch/v1beta1
kind: CronJob kind: CronJob
metadata: metadata:
@@ -43,4 +42,5 @@ spec:
volumes: volumes:
- name: timezone - name: timezone
hostPath: hostPath:
path: /usr/share/zoneinfo/Asia/Shanghai` path: /usr/share/zoneinfo/Asia/Shanghai
`

View File

@@ -7,22 +7,19 @@ import (
"text/template" "text/template"
) )
var ( const (
errUnknownServiceType = errors.New("unknown service type") ServiceTypeApi ServiceType = "api"
ServiceTypeRpc ServiceType = "rpc"
ServiceTypeJob ServiceType = "job"
envDev = "dev"
) )
const ( var errUnknownServiceType = errors.New("unknown service type")
ServiceTypeApi ServiceType = "api"
ServiceTypeRpc ServiceType = "rpc"
ServiceTypeJob ServiceType = "job"
ServiceTypeRmq ServiceType = "rmq"
ServiceTypeSync ServiceType = "sync"
envDev = "dev"
)
type ( type (
ServiceType string ServiceType string
K8sRequest struct {
KubeRequest struct {
Env string Env string
ServiceName string ServiceName string
ServiceType ServiceType ServiceType ServiceType
@@ -41,20 +38,18 @@ type (
} }
) )
func Gen(req K8sRequest) (string, error) { func Gen(req KubeRequest) (string, error) {
switch req.ServiceType { switch req.ServiceType {
case ServiceTypeApi, ServiceTypeRpc: case ServiceTypeApi, ServiceTypeRpc:
return genApiRpc(req) return genApiRpc(req)
case ServiceTypeJob: case ServiceTypeJob:
return genJob(req) return genJob(req)
case ServiceTypeRmq, ServiceTypeSync:
return genRmqSync(req)
default: default:
return "", errUnknownServiceType return "", errUnknownServiceType
} }
} }
func genApiRpc(req K8sRequest) (string, error) { func genApiRpc(req KubeRequest) (string, error) {
t, err := template.New("api_rpc").Parse(apiRpcTmeplate) t, err := template.New("api_rpc").Parse(apiRpcTmeplate)
if err != nil { if err != nil {
return "", err return "", err
@@ -83,33 +78,7 @@ func genApiRpc(req K8sRequest) (string, error) {
return buffer.String(), nil return buffer.String(), nil
} }
func genRmqSync(req K8sRequest) (string, error) { func genJob(req KubeRequest) (string, error) {
t, err := template.New("rmq_sync").Parse(rmqSyncTmeplate)
if err != nil {
return "", err
}
buffer := new(bytes.Buffer)
err = t.Execute(buffer, map[string]interface{}{
"name": fmt.Sprintf("%s-%s", req.ServiceName, req.ServiceType),
"namespace": req.Namespace,
"replicas": req.Replicas,
"revisionHistoryLimit": req.RevisionHistoryLimit,
"limitCpu": req.LimitCpu,
"limitMem": req.LimitMem,
"requestCpu": req.RequestCpu,
"requestMem": req.RequestMem,
"serviceName": req.ServiceName,
"env": req.Env,
"envIsPreOrPro": req.Env != envDev,
"envIsDev": req.Env == envDev,
})
if err != nil {
return "", nil
}
return buffer.String(), nil
}
func genJob(req K8sRequest) (string, error) {
t, err := template.New("job").Parse(jobTmeplate) t, err := template.New("job").Parse(jobTmeplate)
if err != nil { if err != nil {
return "", err return "", err

View File

@@ -1,68 +0,0 @@
package k8s
var rmqSyncTmeplate = `apiVersion: apps/v1beta2
kind: Deployment
metadata:
name: {{.name}}
namespace: {{.namespace}}
labels:
app: {{.name}}
spec:
replicas: {{.replicas}}
revisionHistoryLimit: {{.revisionHistoryLimit}}
selector:
matchLabels:
app: {{.name}}
template:
metadata:
labels:
app: {{.name}}
spec:{{if .envIsDev}}
terminationGracePeriodSeconds: 60{{end}}
containers:
- name: {{.name}}
image: registry-vpc.cn-hangzhou.aliyuncs.com/{{.namespace}}/
lifecycle:
preStop:
exec:
command: ["sh","-c","sleep 5"]
env:
- name: aliyun_logs_k8slog
value: "stdout"
- name: aliyun_logs_k8slog_tags
value: "stage={{.env}}"
- name: aliyun_logs_k8slog_format
value: "json"
resources:
limits:
cpu: {{.limitCpu}}m
memory: {{.limitMem}}Mi
requests:
cpu: {{.requestCpu}}m
memory: {{.requestMem}}Mi
command:
- ./{{.serviceName}}
- -f
- ./{{.name}}.json
volumeMounts:
- name: timezone
mountPath: /etc/localtime
imagePullSecrets:
- name: {{.namespace}}
volumes:
- name: timezone
hostPath:
path: /usr/share/zoneinfo/Asia/Shanghai{{if .envIsPreOrPro}}
---
apiVersion: v1
kind: Service
metadata:
name: {{.name}}-svc
namespace: {{.namespace}}
spec:
selector:
app: {{.name}}
sessionAffinity: None
type: ClusterIP
clusterIP: None{{end}}`

View File

@@ -1,5 +1,9 @@
# Change log # Change log
## 2020-10-19
* 增加template
## 2020-08-20 ## 2020-08-20
* 新增支持通过连接数据库生成model * 新增支持通过连接数据库生成model

View File

@@ -0,0 +1,20 @@
package converter
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestConvertDataType(t *testing.T) {
v, err := ConvertDataType("tinyint")
assert.Nil(t, err)
assert.Equal(t, "int64", v)
v, err = ConvertDataType("timestamp")
assert.Nil(t, err)
assert.Equal(t, "time.Time", v)
_, err = ConvertDataType("float32")
assert.NotNil(t, err)
}

View File

@@ -4,4 +4,4 @@
goctl model mysql ddl -src="./sql/user.sql" -dir="./sql/model" -c goctl model mysql ddl -src="./sql/user.sql" -dir="./sql/model" -c
# generate model with cache from data source # generate model with cache from data source
goctl model mysql datasource -url="user:password@tcp(127.0.0.1:3306)/database" -table="table1,table2" -dir="./model" #goctl model mysql datasource -url="user:password@tcp(127.0.0.1:3306)/database" -table="table1,table2" -dir="./model"

View File

@@ -5,7 +5,7 @@ import (
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
@@ -22,8 +22,13 @@ func genDelete(table Table, withCache bool) (string, error) {
} }
camel := table.Name.ToCamel() camel := table.Name.ToCamel()
output, err := templatex.With("delete"). text, err := util.LoadTemplate(category, deleteTemplateFile, template.Delete)
Parse(template.Delete). if err != nil {
return "", err
}
output, err := util.With("delete").
Parse(text).
Execute(map[string]interface{}{ Execute(map[string]interface{}{
"upperStartCamelObject": camel, "upperStartCamelObject": camel,
"withCache": withCache, "withCache": withCache,

View File

@@ -5,7 +5,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser" "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
func genFields(fields []parser.Field) (string, error) { func genFields(fields []parser.Field) (string, error) {
@@ -25,8 +25,14 @@ func genField(field parser.Field) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
output, err := templatex.With("types").
Parse(template.Field). text, err := util.LoadTemplate(category, fieldTemplateFile, template.Field)
if err != nil {
return "", err
}
output, err := util.With("types").
Parse(text).
Execute(map[string]interface{}{ Execute(map[string]interface{}{
"name": field.Name.ToCamel(), "name": field.Name.ToCamel(),
"type": field.DataType, "type": field.DataType,

View File

@@ -2,14 +2,19 @@ package gen
import ( import (
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
func genFindOne(table Table, withCache bool) (string, error) { func genFindOne(table Table, withCache bool) (string, error) {
camel := table.Name.ToCamel() camel := table.Name.ToCamel()
output, err := templatex.With("findOne"). text, err := util.LoadTemplate(category, findOneTemplateFile, template.FindOne)
Parse(template.FindOne). if err != nil {
return "", err
}
output, err := util.With("findOne").
Parse(text).
Execute(map[string]interface{}{ Execute(map[string]interface{}{
"withCache": withCache, "withCache": withCache,
"upperStartCamelObject": camel, "upperStartCamelObject": camel,

View File

@@ -5,12 +5,17 @@ import (
"strings" "strings"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
func genFindOneByField(table Table, withCache bool) (string, string, error) { func genFindOneByField(table Table, withCache bool) (string, string, error) {
t := templatex.With("findOneByField").Parse(template.FindOneByField) text, err := util.LoadTemplate(category, findOneByFieldTemplateFile, template.FindOneByField)
if err != nil {
return "", "", err
}
t := util.With("findOneByField").Parse(text)
var list []string var list []string
camelTableName := table.Name.ToCamel() camelTableName := table.Name.ToCamel()
for _, field := range table.Fields { for _, field := range table.Fields {
@@ -33,10 +38,16 @@ func genFindOneByField(table Table, withCache bool) (string, string, error) {
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
list = append(list, output.String()) list = append(list, output.String())
} }
if withCache { if withCache {
out, err := templatex.With("findOneByFieldExtraMethod").Parse(template.FindOneByFieldExtraMethod).Execute(map[string]interface{}{ text, err := util.LoadTemplate(category, findOneByFieldExtraMethodTemplateFile, template.FindOneByFieldExtraMethod)
if err != nil {
return "", "", err
}
out, err := util.With("findOneByFieldExtraMethod").Parse(text).Execute(map[string]interface{}{
"upperStartCamelObject": camelTableName, "upperStartCamelObject": camelTableName,
"primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left, "primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left,
"lowerStartCamelObject": stringx.From(camelTableName).UnTitle(), "lowerStartCamelObject": stringx.From(camelTableName).UnTitle(),
@@ -45,6 +56,7 @@ func genFindOneByField(table Table, withCache bool) (string, string, error) {
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
return strings.Join(list, "\n"), out.String(), nil return strings.Join(list, "\n"), out.String(), nil
} }
return strings.Join(list, "\n"), "", nil return strings.Join(list, "\n"), "", nil

View File

@@ -9,7 +9,6 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser" "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/templatex"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
@@ -120,7 +119,7 @@ type (
) )
func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) { func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) {
t := templatex.With("model"). t := util.With("model").
Parse(template.Model). Parse(template.Model).
GoFmt(true) GoFmt(true)

View File

@@ -2,25 +2,37 @@ package gen
import ( import (
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
func genImports(withCache, timeImport bool) (string, error) { func genImports(withCache, timeImport bool) (string, error) {
if withCache { if withCache {
buffer, err := templatex.With("import").Parse(template.Imports).Execute(map[string]interface{}{ text, err := util.LoadTemplate(category, importsTemplateFile, template.Imports)
if err != nil {
return "", err
}
buffer, err := util.With("import").Parse(text).Execute(map[string]interface{}{
"time": timeImport, "time": timeImport,
}) })
if err != nil { if err != nil {
return "", err return "", err
} }
return buffer.String(), nil return buffer.String(), nil
} else { } else {
buffer, err := templatex.With("import").Parse(template.ImportsNoCache).Execute(map[string]interface{}{ text, err := util.LoadTemplate(category, importsWithNoCacheTemplateFile, template.ImportsNoCache)
if err != nil {
return "", err
}
buffer, err := util.With("import").Parse(text).Execute(map[string]interface{}{
"time": timeImport, "time": timeImport,
}) })
if err != nil { if err != nil {
return "", err return "", err
} }
return buffer.String(), nil return buffer.String(), nil
} }
} }

View File

@@ -5,7 +5,7 @@ import (
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
@@ -34,8 +34,13 @@ func genInsert(table Table, withCache bool) (string, error) {
expressionValues = append(expressionValues, "data."+camel) expressionValues = append(expressionValues, "data."+camel)
} }
camel := table.Name.ToCamel() camel := table.Name.ToCamel()
output, err := templatex.With("insert"). text, err := util.LoadTemplate(category, insertTemplateFile, template.Insert)
Parse(template.Insert). if err != nil {
return "", err
}
output, err := util.With("insert").
Parse(text).
Execute(map[string]interface{}{ Execute(map[string]interface{}{
"withCache": withCache, "withCache": withCache,
"containsIndexCache": table.ContainsUniqueKey, "containsIndexCache": table.ContainsUniqueKey,
@@ -49,5 +54,6 @@ func genInsert(table Table, withCache bool) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return output.String(), nil return output.String(), nil
} }

View File

@@ -12,7 +12,7 @@ type (
// {{prefix}}=cache // {{prefix}}=cache
// key:id // key:id
Key struct { Key struct {
VarExpression string // cacheUserIdPrefix="cache#user#id#" VarExpression string // cacheUserIdPrefix = "cache#User#id#"
Left string // cacheUserIdPrefix Left string // cacheUserIdPrefix
Right string // cache#user#id# Right string // cache#user#id#
Variable string // userIdKey Variable string // userIdKey

View File

@@ -0,0 +1,77 @@
package gen
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
func TestGenCacheKeys(t *testing.T) {
m, err := genCacheKeys(parser.Table{
Name: stringx.From("user"),
PrimaryKey: parser.Primary{
Field: parser.Field{
Name: stringx.From("id"),
DataBaseType: "bigint",
DataType: "int64",
IsKey: false,
IsPrimaryKey: true,
IsUniqueKey: false,
Comment: "自增id",
},
AutoIncrement: true,
},
Fields: []parser.Field{
{
Name: stringx.From("mobile"),
DataBaseType: "varchar",
DataType: "string",
IsKey: false,
IsPrimaryKey: false,
IsUniqueKey: true,
Comment: "手机号",
},
{
Name: stringx.From("name"),
DataBaseType: "varchar",
DataType: "string",
IsKey: false,
IsPrimaryKey: false,
IsUniqueKey: true,
Comment: "姓名",
},
{
Name: stringx.From("createTime"),
DataBaseType: "timestamp",
DataType: "time.Time",
IsKey: false,
IsPrimaryKey: false,
IsUniqueKey: false,
Comment: "创建时间",
},
{
Name: stringx.From("updateTime"),
DataBaseType: "timestamp",
DataType: "time.Time",
IsKey: false,
IsPrimaryKey: false,
IsUniqueKey: false,
Comment: "更新时间",
},
},
})
assert.Nil(t, err)
for fieldName, key := range m {
name := stringx.From(fieldName)
assert.Equal(t, fmt.Sprintf(`cacheUser%sPrefix = "cache#User#%s#"`, name.ToCamel(), name.UnTitle()), key.VarExpression)
assert.Equal(t, fmt.Sprintf(`cacheUser%sPrefix`, name.ToCamel()), key.Left)
assert.Equal(t, fmt.Sprintf(`cache#User#%s#`, name.UnTitle()), key.Right)
assert.Equal(t, fmt.Sprintf(`user%sKey`, name.ToCamel()), key.Variable)
assert.Equal(t, `user`+name.ToCamel()+`Key := fmt.Sprintf("%s%v", cacheUser`+name.ToCamel()+`Prefix,`+name.UnTitle()+`)`, key.KeyExpression)
}
}

View File

@@ -2,12 +2,17 @@ package gen
import ( import (
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
func genNew(table Table, withCache bool) (string, error) { func genNew(table Table, withCache bool) (string, error) {
output, err := templatex.With("new"). text, err := util.LoadTemplate(category, modelNewTemplateFile, template.New)
Parse(template.New). if err != nil {
return "", err
}
output, err := util.With("new").
Parse(text).
Execute(map[string]interface{}{ Execute(map[string]interface{}{
"withCache": withCache, "withCache": withCache,
"upperStartCamelObject": table.Name.ToCamel(), "upperStartCamelObject": table.Name.ToCamel(),
@@ -15,5 +20,6 @@ func genNew(table Table, withCache bool) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return output.String(), nil return output.String(), nil
} }

View File

@@ -2,15 +2,20 @@ package gen
import ( import (
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
func genTag(in string) (string, error) { func genTag(in string) (string, error) {
if in == "" { if in == "" {
return in, nil return in, nil
} }
output, err := templatex.With("tag"). text, err := util.LoadTemplate(category, tagTemplateFile, template.Tag)
Parse(template.Tag). if err != nil {
return "", err
}
output, err := util.With("tag").
Parse(text).
Execute(map[string]interface{}{ Execute(map[string]interface{}{
"field": in, "field": in,
}) })

View File

@@ -0,0 +1,72 @@
package gen
import (
"fmt"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli"
)
const (
category = "model"
deleteTemplateFile = "delete.tpl"
fieldTemplateFile = "filed.tpl"
findOneTemplateFile = "find-one.tpl"
findOneByFieldTemplateFile = "find-one-by-field.tpl"
findOneByFieldExtraMethodTemplateFile = "find-one-by-filed-extra-method.tpl"
importsTemplateFile = "import.tpl"
importsWithNoCacheTemplateFile = "import-no-cache.tpl"
insertTemplateFile = "insert.tpl"
modelTemplateFile = "model.tpl"
modelNewTemplateFile = "model-new.tpl"
tagTemplateFile = "tag.tpl"
typesTemplateFile = "types.tpl"
updateTemplateFile = "update.tpl"
varTemplateFile = "var.tpl"
)
var templates = map[string]string{
deleteTemplateFile: template.Delete,
fieldTemplateFile: template.Field,
findOneTemplateFile: template.FindOne,
findOneByFieldTemplateFile: template.FindOneByField,
findOneByFieldExtraMethodTemplateFile: template.FindOneByFieldExtraMethod,
importsTemplateFile: template.Imports,
importsWithNoCacheTemplateFile: template.ImportsNoCache,
insertTemplateFile: template.Insert,
modelTemplateFile: template.Model,
modelNewTemplateFile: template.New,
tagTemplateFile: template.Tag,
typesTemplateFile: template.Types,
updateTemplateFile: template.Update,
varTemplateFile: template.Vars,
}
func GenTemplates(_ *cli.Context) error {
return util.InitTemplates(category, templates)
}
func RevertTemplate(name string) error {
content, ok := templates[name]
if !ok {
return fmt.Errorf("%s: no such file name", name)
}
return util.CreateTemplate(category, name, content)
}
func Clean() error {
return util.Clean(category)
}
func Update(category string) error {
err := Clean()
if err != nil {
return err
}
return util.InitTemplates(category, templates)
}
func GetCategory() string {
return category
}

View File

@@ -0,0 +1,93 @@
package gen
import (
"io/ioutil"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
func TestGenTemplates(t *testing.T) {
err := util.InitTemplates(category, templates)
assert.Nil(t, err)
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, "model-new.tpl")
data, err := ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, string(data), template.New)
}
func TestRevertTemplate(t *testing.T) {
name := "model-new.tpl"
err := util.InitTemplates(category, templates)
assert.Nil(t, err)
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, name)
data, err := ioutil.ReadFile(file)
assert.Nil(t, err)
modifyData := string(data) + "modify"
err = util.CreateTemplate(category, name, modifyData)
assert.Nil(t, err)
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, string(data), modifyData)
assert.Nil(t, RevertTemplate(name))
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, template.New, string(data))
}
func TestClean(t *testing.T) {
name := "model-new.tpl"
err := util.InitTemplates(category, templates)
assert.Nil(t, err)
assert.Nil(t, Clean())
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, name)
_, err = ioutil.ReadFile(file)
assert.NotNil(t, err)
}
func TestUpdate(t *testing.T) {
name := "model-new.tpl"
err := util.InitTemplates(category, templates)
assert.Nil(t, err)
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, name)
data, err := ioutil.ReadFile(file)
assert.Nil(t, err)
modifyData := string(data) + "modify"
err = util.CreateTemplate(category, name, modifyData)
assert.Nil(t, err)
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, string(data), modifyData)
assert.Nil(t, Update(category))
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, template.New, string(data))
}

View File

@@ -2,7 +2,7 @@ package gen
import ( import (
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
func genTypes(table Table, withCache bool) (string, error) { func genTypes(table Table, withCache bool) (string, error) {
@@ -11,8 +11,14 @@ func genTypes(table Table, withCache bool) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
output, err := templatex.With("types").
Parse(template.Types). text, err := util.LoadTemplate(category, typesTemplateFile, template.Types)
if err != nil {
return "", err
}
output, err := util.With("types").
Parse(text).
Execute(map[string]interface{}{ Execute(map[string]interface{}{
"withCache": withCache, "withCache": withCache,
"upperStartCamelObject": table.Name.ToCamel(), "upperStartCamelObject": table.Name.ToCamel(),
@@ -21,5 +27,6 @@ func genTypes(table Table, withCache bool) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return output.String(), nil return output.String(), nil
} }

View File

@@ -4,7 +4,7 @@ import (
"strings" "strings"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
@@ -22,8 +22,13 @@ func genUpdate(table Table, withCache bool) (string, error) {
} }
expressionValues = append(expressionValues, "data."+table.PrimaryKey.Name.ToCamel()) expressionValues = append(expressionValues, "data."+table.PrimaryKey.Name.ToCamel())
camelTableName := table.Name.ToCamel() camelTableName := table.Name.ToCamel()
output, err := templatex.With("update"). text, err := util.LoadTemplate(category, updateTemplateFile, template.Update)
Parse(template.Update). if err != nil {
return "", err
}
output, err := util.With("update").
Parse(text).
Execute(map[string]interface{}{ Execute(map[string]interface{}{
"withCache": withCache, "withCache": withCache,
"upperStartCamelObject": camelTableName, "upperStartCamelObject": camelTableName,
@@ -36,5 +41,6 @@ func genUpdate(table Table, withCache bool) (string, error) {
if err != nil { if err != nil {
return "", nil return "", nil
} }
return output.String(), nil return output.String(), nil
} }

View File

@@ -4,7 +4,7 @@ import (
"strings" "strings"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
@@ -14,8 +14,13 @@ func genVars(table Table, withCache bool) (string, error) {
keys = append(keys, v.VarExpression) keys = append(keys, v.VarExpression)
} }
camel := table.Name.ToCamel() camel := table.Name.ToCamel()
output, err := templatex.With("var"). text, err := util.LoadTemplate(category, varTemplateFile, template.Vars)
Parse(template.Vars). if err != nil {
return "", err
}
output, err := util.With("var").
Parse(text).
GoFmt(true). GoFmt(true).
Execute(map[string]interface{}{ Execute(map[string]interface{}{
"lowerStartCamelObject": stringx.From(camel).UnTitle(), "lowerStartCamelObject": stringx.From(camel).UnTitle(),

View File

@@ -17,11 +17,9 @@ func TestParseSelect(t *testing.T) {
} }
func TestParseCreateTable(t *testing.T) { func TestParseCreateTable(t *testing.T) {
_, err := Parse("CREATE TABLE `user_snake` (\n `id` bigint(10) NOT NULL AUTO_INCREMENT,\n `name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',\n `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',\n `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',\n `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',\n `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `name_index` (`name`),\n KEY `mobile_index` (`mobile`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;") table, err := Parse("CREATE TABLE `user_snake` (\n `id` bigint(10) NOT NULL AUTO_INCREMENT,\n `name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',\n `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',\n `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',\n `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',\n `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `name_index` (`name`),\n KEY `mobile_index` (`mobile`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;")
assert.Nil(t, err)
}
func TestParseCreateTable2(t *testing.T) {
_, err := Parse("create table `user_snake` (\n `id` bigint(10) NOT NULL AUTO_INCREMENT,\n `name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',\n `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',\n `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',\n `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',\n `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `name_index` (`name`),\n KEY `mobile_index` (`mobile`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "user_snake", table.Name.Source())
assert.Equal(t, "id", table.PrimaryKey.Name.Source())
assert.Equal(t, true, table.ContainsTime())
} }

View File

@@ -15,6 +15,7 @@ var (
) )
` `
ImportsNoCache = `import ( ImportsNoCache = `import (
"database/sql"
"strings" "strings"
{{if .time}}"time"{{end}} {{if .time}}"time"{{end}}

View File

@@ -1,15 +1,15 @@
package template package template
var Insert = ` var Insert = `
func (m *{{.upperStartCamelObject}}Model) Insert(data {{.upperStartCamelObject}}) error { func (m *{{.upperStartCamelObject}}Model) Insert(data {{.upperStartCamelObject}}) (sql.Result,error) {
{{if .withCache}}{{if .containsIndexCache}}{{.keys}} {{if .withCache}}{{if .containsIndexCache}}{{.keys}}
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := ` + "`" + `insert into ` + "`" + ` + m.table + ` + "` (` + " + `{{.lowerStartCamelObject}}RowsExpectAutoSet` + " + `) values ({{.expression}})` " + ` query := ` + "`" + `insert into ` + "`" + ` + m.table + ` + "` (` + " + `{{.lowerStartCamelObject}}RowsExpectAutoSet` + " + `) values ({{.expression}})` " + `
return conn.Exec(query, {{.expressionValues}}) return conn.Exec(query, {{.expressionValues}})
}, {{.keyValues}}){{else}}query := ` + "`" + `insert into ` + "`" + ` + m.table + ` + "` (` + " + `{{.lowerStartCamelObject}}RowsExpectAutoSet` + " + `) values ({{.expression}})` " + ` }, {{.keyValues}}){{else}}query := ` + "`" + `insert into ` + "`" + ` + m.table + ` + "` (` + " + `{{.lowerStartCamelObject}}RowsExpectAutoSet` + " + `) values ({{.expression}})` " + `
_,err:=m.ExecNoCache(query, {{.expressionValues}}) ret,err:=m.ExecNoCache(query, {{.expressionValues}})
{{end}}{{else}}query := ` + "`" + `insert into ` + "`" + ` + m.table + ` + "` (` + " + `{{.lowerStartCamelObject}}RowsExpectAutoSet` + " + `) values ({{.expression}})` " + ` {{end}}{{else}}query := ` + "`" + `insert into ` + "`" + ` + m.table + ` + "` (` + " + `{{.lowerStartCamelObject}}RowsExpectAutoSet` + " + `) values ({{.expression}})` " + `
_,err:=m.conn.Exec(query, {{.expressionValues}}){{end}} ret,err:=m.conn.Exec(query, {{.expressionValues}}){{end}}
return err return ret,err
} }
` `

View File

@@ -1,13 +1,13 @@
package template package template
var Update = ` var Update = `
func (m *{{.upperStartCamelObject}}Model) Update(data {{.upperStartCamelObject}}) error { func (m *{{.upperStartCamelObject}}Model) Update(data {{.upperStartCamelObject}}) (sql.Result,error) {
{{if .withCache}}{{.primaryCacheKey}} {{if .withCache}}{{.primaryCacheKey}}
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := ` + "`" + `update ` + "` +" + `m.table +` + "` " + `set ` + "` +" + `{{.lowerStartCamelObject}}RowsWithPlaceHolder` + " + `" + ` where {{.originalPrimaryKey}} = ?` + "`" + ` query := ` + "`" + `update ` + "` +" + `m.table +` + "` " + `set ` + "` +" + `{{.lowerStartCamelObject}}RowsWithPlaceHolder` + " + `" + ` where {{.originalPrimaryKey}} = ?` + "`" + `
return conn.Exec(query, {{.expressionValues}}) return conn.Exec(query, {{.expressionValues}})
}, {{.primaryKeyVariable}}){{else}}query := ` + "`" + `update ` + "` +" + `m.table +` + "` " + `set ` + "` +" + `{{.lowerStartCamelObject}}RowsWithPlaceHolder` + " + `" + ` where {{.originalPrimaryKey}} = ?` + "`" + ` }, {{.primaryKeyVariable}}){{else}}query := ` + "`" + `update ` + "` +" + `m.table +` + "` " + `set ` + "` +" + `{{.lowerStartCamelObject}}RowsWithPlaceHolder` + " + `" + ` where {{.originalPrimaryKey}} = ?` + "`" + `
_,err:=m.conn.Exec(query, {{.expressionValues}}){{end}} ret,err:=m.conn.Exec(query, {{.expressionValues}}){{end}}
return err return ret,err
} }
` `

View File

@@ -1,5 +1,9 @@
# Change log # Change log
## 2020-10-19
* 增加template
## 2020-09-10 ## 2020-09-10
* rpc greet服务一键生成 * rpc greet服务一键生成

View File

@@ -32,7 +32,7 @@ func NewDefaultRpcGenerator(ctx *ctx.RpcContext) *defaultRpcGenerator {
} }
func (g *defaultRpcGenerator) Generate() (err error) { func (g *defaultRpcGenerator) Generate() (err error) {
g.Ctx.Info(aurora.Blue("-> goctl rpc reference documents: ").String() + "「https://github.com/tal-tech/go-zero/blob/master/doc/goctl-rpc.md」") g.Ctx.Info(aurora.Blue("-> goctl rpc reference documents: ").String() + "「https://github.com/tal-tech/zero-doc/blob/main/doc/goctl-rpc.md」")
g.Ctx.Warning("-> generating rpc code ...") g.Ctx.Warning("-> generating rpc code ...")
defer func() { defer func() {
if err == nil { if err == nil {

View File

@@ -7,7 +7,6 @@ import (
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/templatex"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
@@ -123,8 +122,12 @@ func (g *defaultRpcGenerator) genCall() error {
} }
filename := filepath.Join(callPath, typesFilename) filename := filepath.Join(callPath, typesFilename)
head := templatex.GetHead(g.Ctx.ProtoSource) head := util.GetHead(g.Ctx.ProtoSource)
err = templatex.With("types").GoFmt(true).Parse(callTemplateTypes).SaveTo(map[string]interface{}{ text, err := util.LoadTemplate(category, callTypesTemplateFile, callTemplateTypes)
if err != nil {
return err
}
err = util.With("types").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head, "head": head,
"const": constLit, "const": constLit,
"filePackage": service.Name.Lower(), "filePackage": service.Name.Lower(),
@@ -146,8 +149,11 @@ func (g *defaultRpcGenerator) genCall() error {
if err != nil { if err != nil {
return err return err
} }
text, err = util.LoadTemplate(category, callTemplateFile, callTemplateText)
err = templatex.With("shared").GoFmt(true).Parse(callTemplateText).SaveTo(map[string]interface{}{ if err != nil {
return err
}
err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"name": service.Name.Lower(), "name": service.Name.Lower(),
"head": head, "head": head,
"filePackage": service.Name.Lower(), "filePackage": service.Name.Lower(),
@@ -167,7 +173,11 @@ func (g *defaultRpcGenerator) genFunction(service *parser.RpcService) ([]string,
imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb))) imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb)))
for _, method := range service.Funcs { for _, method := range service.Funcs {
imports.AddStr(g.ast.Imports[method.ParameterIn.Package]) imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
buffer, err := templatex.With("sharedFn").Parse(callFunctionTemplate).Execute(map[string]interface{}{ text, err := util.LoadTemplate(category, callFunctionTemplateFile, callFunctionTemplate)
if err != nil {
return nil, nil, err
}
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
"rpcServiceName": service.Name.Title(), "rpcServiceName": service.Name.Title(),
"method": method.Name.Title(), "method": method.Name.Title(),
"package": pkgName, "package": pkgName,
@@ -190,7 +200,12 @@ func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]s
functions := make([]string, 0) functions := make([]string, 0)
for _, method := range service.Funcs { for _, method := range service.Funcs {
buffer, err := templatex.With("interfaceFn").Parse(callInterfaceFunctionTemplate).Execute( text, err := util.LoadTemplate(category, callInterfaceFunctionTemplateFile, callInterfaceFunctionTemplate)
if err != nil {
return nil, err
}
buffer, err := util.With("interfaceFn").Parse(text).Execute(
map[string]interface{}{ map[string]interface{}{
"hasComment": method.HaveDoc(), "hasComment": method.HaveDoc(),
"comment": method.GetDoc(), "comment": method.GetDoc(),

View File

@@ -23,5 +23,11 @@ func (g *defaultRpcGenerator) genConfig() error {
if util.FileExists(fileName) { if util.FileExists(fileName) {
return nil return nil
} }
return ioutil.WriteFile(fileName, []byte(configTemplate), os.ModePerm)
text, err := util.LoadTemplate(category, configTemplateFileFile, configTemplate)
if err != nil {
return err
}
return ioutil.WriteFile(fileName, []byte(text), os.ModePerm)
} }

View File

@@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"path/filepath" "path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/templatex"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
@@ -23,7 +22,12 @@ func (g *defaultRpcGenerator) genEtc() error {
return nil return nil
} }
return templatex.With("etc").Parse(etcTemplate).SaveTo(map[string]interface{}{ text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate)
if err != nil {
return err
}
return util.With("etc").Parse(text).SaveTo(map[string]interface{}{
"serviceName": g.Ctx.ServiceName.Lower(), "serviceName": g.Ctx.ServiceName.Lower(),
}, fileName, false) }, fileName, false)
} }

View File

@@ -7,7 +7,6 @@ import (
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/templatex"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
@@ -62,7 +61,11 @@ func (g *defaultRpcGenerator) genLogic() error {
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc)) svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
imports.AddStr(svcImport) imports.AddStr(svcImport)
imports.AddStr(importList...) imports.AddStr(importList...)
err = templatex.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{ text, err := util.LoadTemplate(category, logicTemplateFileFile, logicTemplate)
if err != nil {
return err
}
err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()), "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"functions": functions, "functions": functions,
"imports": strings.Join(imports.KeysStr(), util.NL), "imports": strings.Join(imports.KeysStr(), util.NL),
@@ -83,7 +86,12 @@ func (g *defaultRpcGenerator) genLogicFunction(packageName string, method *parse
} }
imports.AddStr(g.ast.Imports[method.ParameterIn.Package]) imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
imports.AddStr(g.ast.Imports[method.ParameterOut.Package]) imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
buffer, err := templatex.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{ text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
if err != nil {
return "", nil, err
}
buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()), "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(), "method": method.Name.Title(),
"request": method.ParameterIn.StarExpression, "request": method.ParameterIn.StarExpression,
@@ -95,6 +103,7 @@ func (g *defaultRpcGenerator) genLogicFunction(packageName string, method *parse
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
functions = append(functions, buffer.String()) functions = append(functions, buffer.String())
return strings.Join(functions, util.NL), imports.KeysStr(), nil return strings.Join(functions, util.NL), imports.KeysStr(), nil
} }

View File

@@ -6,7 +6,6 @@ import (
"strings" "strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/templatex"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
@@ -58,8 +57,13 @@ func (g *defaultRpcGenerator) genMain() error {
configImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig)) configImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig))
imports = append(imports, configImport, pbImport, remoteImport, svcImport) imports = append(imports, configImport, pbImport, remoteImport, svcImport)
srv, registers := g.genServer(pkg, file.Service) srv, registers := g.genServer(pkg, file.Service)
head := templatex.GetHead(g.Ctx.ProtoSource) head := util.GetHead(g.Ctx.ProtoSource)
return templatex.With("main").GoFmt(true).Parse(mainTemplate).SaveTo(map[string]interface{}{ text, err := util.LoadTemplate(category, mainTemplateFile, mainTemplate)
if err != nil {
return err
}
return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head, "head": head,
"package": pkg, "package": pkg,
"serviceName": g.Ctx.ServiceName.Lower(), "serviceName": g.Ctx.ServiceName.Lower(),

View File

@@ -18,6 +18,7 @@ const (
func (g *defaultRpcGenerator) genPb() error { func (g *defaultRpcGenerator) genPb() error {
pbPath := g.dirM[dirPb] pbPath := g.dirM[dirPb]
// deprecated: containsAny will be removed in the feature
imports, containsAny, err := parser.ParseImport(g.Ctx.ProtoFileSrc) imports, containsAny, err := parser.ParseImport(g.Ctx.ProtoFileSrc)
if err != nil { if err != nil {
return err return err

View File

@@ -7,7 +7,6 @@ import (
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/templatex"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
@@ -52,7 +51,7 @@ func (g *defaultRpcGenerator) genHandler() error {
imports := collection.NewSet() imports := collection.NewSet()
imports.AddStr(logicImport, svcImport) imports.AddStr(logicImport, svcImport)
head := templatex.GetHead(g.Ctx.ProtoSource) head := util.GetHead(g.Ctx.ProtoSource)
for _, service := range file.Service { for _, service := range file.Service {
filename := fmt.Sprintf("%vserver.go", service.Name.Lower()) filename := fmt.Sprintf("%vserver.go", service.Name.Lower())
serverFile := filepath.Join(serverPath, filename) serverFile := filepath.Join(serverPath, filename)
@@ -60,8 +59,14 @@ func (g *defaultRpcGenerator) genHandler() error {
if err != nil { if err != nil {
return err return err
} }
imports.AddStr(importList...) imports.AddStr(importList...)
err = templatex.With("server").GoFmt(true).Parse(serverTemplate).SaveTo(map[string]interface{}{ text, err := util.LoadTemplate(category, serverTemplateFile, serverTemplate)
if err != nil {
return err
}
err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head, "head": head,
"types": fmt.Sprintf(typeFmt, service.Name.Title()), "types": fmt.Sprintf(typeFmt, service.Name.Title()),
"server": service.Name.Title(), "server": service.Name.Title(),
@@ -86,7 +91,12 @@ func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string
} }
imports.AddStr(g.ast.Imports[method.ParameterIn.Package]) imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
imports.AddStr(g.ast.Imports[method.ParameterOut.Package]) imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
buffer, err := templatex.With("func").Parse(functionTemplate).Execute(map[string]interface{}{ text, err := util.LoadTemplate(category, serverFuncTemplateFile, functionTemplate)
if err != nil {
return nil, nil, err
}
buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
"server": service.Name.Title(), "server": service.Name.Title(),
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()), "logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(), "method": method.Name.Title(),
@@ -99,6 +109,7 @@ func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
functionList = append(functionList, buffer.String()) functionList = append(functionList, buffer.String())
} }
return functionList, imports.KeysStr(), nil return functionList, imports.KeysStr(), nil

View File

@@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"path/filepath" "path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
const svcTemplate = `package svc const svcTemplate = `package svc
@@ -25,7 +25,12 @@ func NewServiceContext(c config.Config) *ServiceContext {
func (g *defaultRpcGenerator) genSvc() error { func (g *defaultRpcGenerator) genSvc() error {
svcPath := g.dirM[dirSvc] svcPath := g.dirM[dirSvc]
fileName := filepath.Join(svcPath, fileServiceContext) fileName := filepath.Join(svcPath, fileServiceContext)
return templatex.With("svc").GoFmt(true).Parse(svcTemplate).SaveTo(map[string]interface{}{ text, err := util.LoadTemplate(category, svcTemplateFile, svcTemplate)
if err != nil {
return err
}
return util.With("svc").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"imports": fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig)), "imports": fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig)),
}, fileName, false) }, fileName, false)
} }

View File

@@ -0,0 +1,59 @@
package gen
import (
"path/filepath"
"strings"
"github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const rpcTemplateText = `syntax = "proto3";
package {{.package}};
message Request {
string ping = 1;
}
message Response {
string pong = 1;
}
service {{.serviceName}} {
rpc Ping(Request) returns(Response);
}
`
type rpcTemplate struct {
out string
console.Console
}
func NewRpcTemplate(out string, idea bool) *rpcTemplate {
return &rpcTemplate{
out: out,
Console: console.NewConsole(idea),
}
}
func (r *rpcTemplate) MustGenerate(showState bool) {
r.Info(aurora.Blue("-> goctl rpc reference documents: ").String() + "「https://github.com/tal-tech/zero-doc/blob/main/doc/goctl-rpc.md」")
r.Info("-> generating template...")
protoFilename := filepath.Base(r.out)
serviceName := stringx.From(strings.TrimSuffix(protoFilename, filepath.Ext(protoFilename)))
text, err := util.LoadTemplate(category, rpcTemplateFile, rpcTemplateText)
r.Must(err)
err = util.With("t").Parse(text).SaveTo(map[string]string{
"package": serviceName.UnTitle(),
"serviceName": serviceName.Title(),
}, r.out, false)
r.Must(err)
if showState {
r.Success("Done.")
}
}

View File

@@ -1,54 +1,69 @@
package gen package gen
import ( import (
"path/filepath" "fmt"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/templatex" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/urfave/cli"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
const rpcTemplateText = `syntax = "proto3"; const (
category = "rpc"
callTemplateFile = "call.tpl"
callTypesTemplateFile = "call-types.tpl"
callInterfaceFunctionTemplateFile = "call-interface-func.tpl"
callFunctionTemplateFile = "call-func.tpl"
configTemplateFileFile = "config.tpl"
etcTemplateFileFile = "etc.tpl"
logicTemplateFileFile = "logic.tpl"
logicFuncTemplateFileFile = "logic-func.tpl"
mainTemplateFile = "main.tpl"
serverTemplateFile = "server.tpl"
serverFuncTemplateFile = "server-func.tpl"
svcTemplateFile = "svc.tpl"
rpcTemplateFile = "template.tpl"
)
package {{.package}}; var templates = map[string]string{
callTemplateFile: callTemplateText,
message Request { callTypesTemplateFile: callTemplateTypes,
string ping = 1; callInterfaceFunctionTemplateFile: callInterfaceFunctionTemplate,
callFunctionTemplateFile: callFunctionTemplate,
configTemplateFileFile: configTemplate,
etcTemplateFileFile: etcTemplate,
logicTemplateFileFile: logicTemplate,
logicFuncTemplateFileFile: logicFunctionTemplate,
mainTemplateFile: mainTemplate,
serverTemplateFile: serverTemplate,
serverFuncTemplateFile: functionTemplate,
svcTemplateFile: svcTemplate,
rpcTemplateFile: rpcTemplateText,
} }
message Response { func GenTemplates(_ *cli.Context) error {
string pong = 1; return util.InitTemplates(category, templates)
} }
service {{.serviceName}} { func RevertTemplate(name string) error {
rpc Ping(Request) returns(Response); content, ok := templates[name]
} if !ok {
` return fmt.Errorf("%s: no such file name", name)
type rpcTemplate struct {
out string
console.Console
}
func NewRpcTemplate(out string, idea bool) *rpcTemplate {
return &rpcTemplate{
out: out,
Console: console.NewConsole(idea),
} }
return util.CreateTemplate(category, name, content)
} }
func (r *rpcTemplate) MustGenerate(showState bool) { func Clean() error {
r.Info("查看rpc生成请移步至「https://github.com/tal-tech/zero-doc/blob/main/doc/goctl-rpc.md」") return util.Clean(category)
r.Info("generating template...") }
protoFilename := filepath.Base(r.out)
serviceName := stringx.From(strings.TrimSuffix(protoFilename, filepath.Ext(protoFilename))) func Update(category string) error {
err := templatex.With("t").Parse(rpcTemplateText).SaveTo(map[string]string{ err := Clean()
"package": serviceName.UnTitle(), if err != nil {
"serviceName": serviceName.Title(), return err
}, r.out, false) }
r.Must(err) return util.InitTemplates(category, templates)
if showState { }
r.Success("Done.")
} func GetCategory() string {
return category
} }

View File

@@ -0,0 +1,92 @@
package gen
import (
"io/ioutil"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
func TestGenTemplates(t *testing.T) {
err := util.InitTemplates(category, templates)
assert.Nil(t, err)
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, "main.tpl")
data, err := ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, string(data), mainTemplate)
}
func TestRevertTemplate(t *testing.T) {
name := "main.tpl"
err := util.InitTemplates(category, templates)
assert.Nil(t, err)
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, name)
data, err := ioutil.ReadFile(file)
assert.Nil(t, err)
modifyData := string(data) + "modify"
err = util.CreateTemplate(category, name, modifyData)
assert.Nil(t, err)
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, string(data), modifyData)
assert.Nil(t, RevertTemplate(name))
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, mainTemplate, string(data))
}
func TestClean(t *testing.T) {
name := "main.tpl"
err := util.InitTemplates(category, templates)
assert.Nil(t, err)
assert.Nil(t, Clean())
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, name)
_, err = ioutil.ReadFile(file)
assert.NotNil(t, err)
}
func TestUpdate(t *testing.T) {
name := "main.tpl"
err := util.InitTemplates(category, templates)
assert.Nil(t, err)
dir, err := util.GetTemplateDir(category)
assert.Nil(t, err)
file := filepath.Join(dir, name)
data, err := ioutil.ReadFile(file)
assert.Nil(t, err)
modifyData := string(data) + "modify"
err = util.CreateTemplate(category, name, modifyData)
assert.Nil(t, err)
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, string(data), modifyData)
assert.Nil(t, Update(category))
data, err = ioutil.ReadFile(file)
assert.Nil(t, err)
assert.Equal(t, mainTemplate, string(data))
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/tal-tech/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
sx "github.com/tal-tech/go-zero/core/stringx" sx "github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/templatex"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
@@ -98,6 +97,7 @@ type (
} }
// parsing for rpc // parsing for rpc
PbAst struct { PbAst struct {
// deprecated: containsAny will be removed in the feature
ContainsAny bool ContainsAny bool
Imports map[string]string Imports map[string]string
Structure map[string]*Struct Structure map[string]*Struct
@@ -590,7 +590,7 @@ func (a *PbAst) GenTypesCode() (string, error) {
types = append(types, typeCode) types = append(types, typeCode)
} }
buffer, err := templatex.With("type").Parse(typeTemplate).Execute(map[string]interface{}{ buffer, err := util.With("type").Parse(typeTemplate).Execute(map[string]interface{}{
"types": strings.Join(types, util.NL+util.NL), "types": strings.Join(types, util.NL+util.NL),
}) })
if err != nil { if err != nil {
@@ -615,7 +615,7 @@ func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
comment = f.Comment[0] comment = f.Comment[0]
} }
doc = strings.Join(f.Document, util.NL) doc = strings.Join(f.Document, util.NL)
buffer, err := templatex.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{ buffer, err := util.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{
"name": f.Name.Title(), "name": f.Name.Title(),
"type": f.Type.InvokeTypeExpression, "type": f.Type.InvokeTypeExpression,
"tag": f.JsonTag, "tag": f.JsonTag,
@@ -630,7 +630,7 @@ func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
fields = append(fields, buffer.String()) fields = append(fields, buffer.String())
} }
buffer, err := templatex.With("struct").Parse(structTemplate).Execute(map[string]interface{}{ buffer, err := util.With("struct").Parse(structTemplate).Execute(map[string]interface{}{
"type": containsTypeStatement, "type": containsTypeStatement,
"name": s.Name.Title(), "name": s.Name.Title(),
"fields": strings.Join(fields, util.NL), "fields": strings.Join(fields, util.NL),

View File

@@ -10,7 +10,6 @@ import (
"github.com/emicklei/proto" "github.com/emicklei/proto"
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/tools/goctl/templatex"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
@@ -48,9 +47,10 @@ type (
} }
Proto struct { Proto struct {
Package string Package string
Import []*Import Import []*Import
PbSrc string PbSrc string
// deprecated: containsAny will be removed in the feature
ContainsAny bool ContainsAny bool
Message map[string]lang.PlaceholderType Message map[string]lang.PlaceholderType
Enum map[string]*Enum Enum map[string]*Enum
@@ -263,7 +263,7 @@ func (e *Enum) GenEnumCode() (string, error) {
} }
element = append(element, code) element = append(element, code)
} }
buffer, err := templatex.With("enum").Parse(enumTemplate).Execute(map[string]interface{}{ buffer, err := util.With("enum").Parse(enumTemplate).Execute(map[string]interface{}{
"element": strings.Join(element, util.NL), "element": strings.Join(element, util.NL),
}) })
if err != nil { if err != nil {
@@ -273,7 +273,7 @@ func (e *Enum) GenEnumCode() (string, error) {
} }
func (e *Enum) GenEnumTypeCode() (string, error) { func (e *Enum) GenEnumTypeCode() (string, error) {
buffer, err := templatex.With("enumAlias").Parse(enumTypeTemplate).Execute(map[string]interface{}{ buffer, err := util.With("enumAlias").Parse(enumTypeTemplate).Execute(map[string]interface{}{
"name": e.Name.Source(), "name": e.Name.Source(),
}) })
if err != nil { if err != nil {
@@ -283,7 +283,7 @@ func (e *Enum) GenEnumTypeCode() (string, error) {
} }
func (e *EnumField) GenEnumFieldCode(parentName string) (string, error) { func (e *EnumField) GenEnumFieldCode(parentName string) (string, error) {
buffer, err := templatex.With("enumField").Parse(enumFiledTemplate).Execute(map[string]interface{}{ buffer, err := util.With("enumField").Parse(enumFiledTemplate).Execute(map[string]interface{}{
"key": e.Key, "key": e.Key,
"name": parentName, "name": parentName,
"value": e.Value, "value": e.Value,

View File

@@ -0,0 +1,102 @@
package tpl
import (
"fmt"
"github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/core/errorx"
"github.com/tal-tech/go-zero/tools/goctl/api/gogen"
modelgen "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/gen"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli"
)
const templateParentPath = "/"
func GenTemplates(ctx *cli.Context) error {
if err := errorx.Chain(
func() error {
return gogen.GenTemplates(ctx)
},
func() error {
return modelgen.GenTemplates(ctx)
},
func() error {
return rpcgen.GenTemplates(ctx)
},
); err != nil {
return err
}
dir, err := util.GetTemplateDir(templateParentPath)
if err != nil {
return err
}
fmt.Printf("Templates are generated in %s, %s\n", aurora.Green(dir),
aurora.Red("edit on your risk!"))
return nil
}
func CleanTemplates(_ *cli.Context) error {
err := errorx.Chain(
func() error {
return gogen.Clean()
},
func() error {
return modelgen.Clean()
},
func() error {
return rpcgen.Clean()
},
)
if err != nil {
return err
}
fmt.Printf("%s\n", aurora.Green("template are clean!"))
return nil
}
func UpdateTemplates(ctx *cli.Context) (err error) {
category := ctx.String("category")
defer func() {
if err == nil {
fmt.Println(aurora.Green(fmt.Sprintf("%s template are update!", category)).String())
}
}()
switch category {
case gogen.GetCategory():
return gogen.Update(category)
case rpcgen.GetCategory():
return rpcgen.Update(category)
case modelgen.GetCategory():
return modelgen.Update(category)
default:
err = fmt.Errorf("unexpected category: %s", category)
return
}
}
func RevertTemplates(ctx *cli.Context) (err error) {
category := ctx.String("category")
filename := ctx.String("name")
defer func() {
if err == nil {
fmt.Println(aurora.Green(fmt.Sprintf("%s template are reverted!", filename)).String())
}
}()
switch category {
case gogen.GetCategory():
return gogen.RevertTemplate(filename)
case rpcgen.GetCategory():
return rpcgen.RevertTemplate(filename)
case modelgen.GetCategory():
return modelgen.RevertTemplate(filename)
default:
err = fmt.Errorf("unexpected category: %s", category)
return
}
}

View File

@@ -1,47 +1,65 @@
package templatex package util
import ( import (
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/tools/goctl/util"
) )
const goctlDir = ".goctl" const goctlDir = ".goctl"
func GetTemplateDir(category string) (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, goctlDir, category), nil
}
func InitTemplates(category string, templates map[string]string) error { func InitTemplates(category string, templates map[string]string) error {
dir, err := getTemplateDir(category) dir, err := GetTemplateDir(category)
if err != nil { if err != nil {
return err return err
} }
if err := util.MkdirIfNotExist(dir); err != nil { if err := MkdirIfNotExist(dir); err != nil {
return err return err
} }
for k, v := range templates { for k, v := range templates {
if err := createTemplate(filepath.Join(dir, k), v); err != nil { if err := createTemplate(filepath.Join(dir, k), v, false); err != nil {
return err return err
} }
} }
fmt.Printf("Templates are generated in %s, %s\n", aurora.Green(dir),
aurora.Red("edit on your risk!"))
return nil return nil
} }
func CreateTemplate(category, name, content string) error {
dir, err := GetTemplateDir(category)
if err != nil {
return err
}
return createTemplate(filepath.Join(dir, name), content, true)
}
func Clean(category string) error {
dir, err := GetTemplateDir(category)
if err != nil {
return err
}
return os.RemoveAll(dir)
}
func LoadTemplate(category, file, builtin string) (string, error) { func LoadTemplate(category, file, builtin string) (string, error) {
dir, err := getTemplateDir(category) dir, err := GetTemplateDir(category)
if err != nil { if err != nil {
return "", err return "", err
} }
file = filepath.Join(dir, file) file = filepath.Join(dir, file)
if !util.FileExists(file) { if !FileExists(file) {
return builtin, nil return builtin, nil
} }
@@ -53,9 +71,8 @@ func LoadTemplate(category, file, builtin string) (string, error) {
return string(content), nil return string(content), nil
} }
func createTemplate(file, content string) error { func createTemplate(file, content string, force bool) error {
if util.FileExists(file) { if FileExists(file) && !force {
println(1)
return nil return nil
} }
@@ -68,12 +85,3 @@ func createTemplate(file, content string) error {
_, err = f.WriteString(content) _, err = f.WriteString(content)
return err return err
} }
func getTemplateDir(category string) (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, goctlDir, category), nil
}

View File

@@ -1,4 +1,4 @@
package templatex package util
var headTemplate = `// Code generated by goctl. DO NOT EDIT! var headTemplate = `// Code generated by goctl. DO NOT EDIT!
// Source: {{.source}}` // Source: {{.source}}`

View File

@@ -1,12 +1,10 @@
package templatex package util
import ( import (
"bytes" "bytes"
goformat "go/format" goformat "go/format"
"io/ioutil" "io/ioutil"
"text/template" "text/template"
"github.com/tal-tech/go-zero/tools/goctl/util"
) )
const regularPerm = 0666 const regularPerm = 0666
@@ -34,7 +32,7 @@ func (t *defaultTemplate) GoFmt(format bool) *defaultTemplate {
} }
func (t *defaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool) error { func (t *defaultTemplate) SaveTo(data interface{}, path string, forceUpdate bool) error {
if util.FileExists(path) && !forceUpdate { if FileExists(path) && !forceUpdate {
return nil return nil
} }

View File

@@ -11,15 +11,15 @@ import (
) )
var ( var (
WithDialOption = internal.WithDialOption WithDialOption = internal.WithDialOption
WithTimeout = internal.WithTimeout WithTimeout = internal.WithTimeout
WithUnaryClientInterceptor = internal.WithUnaryClientInterceptor
) )
type ( type (
ClientOption = internal.ClientOption ClientOption = internal.ClientOption
Client interface { Client interface {
AddInterceptor(interceptor grpc.UnaryClientInterceptor)
Conn() *grpc.ClientConn Conn() *grpc.ClientConn
} }
@@ -66,8 +66,8 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
}, nil }, nil
} }
func NewClientNoAuth(c discov.EtcdConf) (Client, error) { func NewClientNoAuth(c discov.EtcdConf, opts ...ClientOption) (Client, error) {
client, err := internal.NewClient(internal.BuildDiscovTarget(c.Hosts, c.Key)) client, err := internal.NewClient(internal.BuildDiscovTarget(c.Hosts, c.Key), opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -81,10 +81,6 @@ func NewClientWithTarget(target string, opts ...ClientOption) (Client, error) {
return internal.NewClient(target, opts...) return internal.NewClient(target, opts...)
} }
func (rc *RpcClient) AddInterceptor(interceptor grpc.UnaryClientInterceptor) {
rc.client.AddInterceptor(interceptor)
}
func (rc *RpcClient) Conn() *grpc.ClientConn { func (rc *RpcClient) Conn() *grpc.ClientConn {
return rc.client.Conn() return rc.client.Conn()
} }

View File

@@ -60,14 +60,26 @@ func TestDepositServer_Deposit(t *testing.T) {
}, },
} }
directClient := MustNewClient(RpcClientConf{ directClient := MustNewClient(
Endpoints: []string{"foo"}, RpcClientConf{
App: "foo", Endpoints: []string{"foo"},
Token: "bar", App: "foo",
Timeout: 1000, Token: "bar",
}, WithDialOption(grpc.WithInsecure()), WithDialOption(grpc.WithContextDialer(dialer()))) Timeout: 1000,
},
WithDialOption(grpc.WithInsecure()),
WithDialOption(grpc.WithContextDialer(dialer())),
WithUnaryClientInterceptor(func(ctx context.Context, method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
return invoker(ctx, method, req, reply, cc, opts...)
}),
)
targetClient, err := NewClientWithTarget("foo", WithDialOption(grpc.WithInsecure()), targetClient, err := NewClientWithTarget("foo", WithDialOption(grpc.WithInsecure()),
WithDialOption(grpc.WithContextDialer(dialer()))) WithDialOption(grpc.WithContextDialer(dialer())), WithUnaryClientInterceptor(
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
return invoker(ctx, method, req, reply, cc, opts...)
}))
assert.Nil(t, err) assert.Nil(t, err)
clients := []Client{ clients := []Client{
directClient, directClient,

View File

@@ -31,14 +31,13 @@ type (
ClientOption func(options *ClientOptions) ClientOption func(options *ClientOptions)
client struct { client struct {
conn *grpc.ClientConn conn *grpc.ClientConn
interceptors []grpc.UnaryClientInterceptor
} }
) )
func NewClient(target string, opts ...ClientOption) (*client, error) { func NewClient(target string, opts ...ClientOption) (*client, error) {
var cli client var cli client
opts = append(opts, WithDialOption(grpc.WithBalancerName(p2c.Name))) opts = append([]ClientOption{WithDialOption(grpc.WithBalancerName(p2c.Name))}, opts...)
if err := cli.dial(target, opts...); err != nil { if err := cli.dial(target, opts...); err != nil {
return nil, err return nil, err
} }
@@ -46,18 +45,14 @@ func NewClient(target string, opts ...ClientOption) (*client, error) {
return &cli, nil return &cli, nil
} }
func (c *client) AddInterceptor(interceptor grpc.UnaryClientInterceptor) {
c.interceptors = append(c.interceptors, interceptor)
}
func (c *client) Conn() *grpc.ClientConn { func (c *client) Conn() *grpc.ClientConn {
return c.conn return c.conn
} }
func (c *client) buildDialOptions(opts ...ClientOption) []grpc.DialOption { func (c *client) buildDialOptions(opts ...ClientOption) []grpc.DialOption {
var clientOptions ClientOptions var cliOpts ClientOptions
for _, opt := range opts { for _, opt := range opts {
opt(&clientOptions) opt(&cliOpts)
} }
options := []grpc.DialOption{ options := []grpc.DialOption{
@@ -68,14 +63,11 @@ func (c *client) buildDialOptions(opts ...ClientOption) []grpc.DialOption {
clientinterceptors.DurationInterceptor, clientinterceptors.DurationInterceptor,
clientinterceptors.BreakerInterceptor, clientinterceptors.BreakerInterceptor,
clientinterceptors.PrometheusInterceptor, clientinterceptors.PrometheusInterceptor,
clientinterceptors.TimeoutInterceptor(clientOptions.Timeout), clientinterceptors.TimeoutInterceptor(cliOpts.Timeout),
), ),
} }
for _, interceptor := range c.interceptors {
options = append(options, WithUnaryClientInterceptors(interceptor))
}
return append(options, clientOptions.DialOptions...) return append(options, cliOpts.DialOptions...)
} }
func (c *client) dial(server string, opts ...ClientOption) error { func (c *client) dial(server string, opts ...ClientOption) error {
@@ -111,3 +103,9 @@ func WithTimeout(timeout time.Duration) ClientOption {
options.Timeout = timeout options.Timeout = timeout
} }
} }
func WithUnaryClientInterceptor(interceptor grpc.UnaryClientInterceptor) ClientOption {
return func(options *ClientOptions) {
options.DialOptions = append(options.DialOptions, WithUnaryClientInterceptors(interceptor))
}
}

View File

@@ -1,6 +1,7 @@
package internal package internal
import ( import (
"context"
"testing" "testing"
"time" "time"
@@ -23,6 +24,16 @@ func TestWithTimeout(t *testing.T) {
assert.Equal(t, time.Second, options.Timeout) assert.Equal(t, time.Second, options.Timeout)
} }
func TestWithUnaryClientInterceptor(t *testing.T) {
var options ClientOptions
opt := WithUnaryClientInterceptor(func(ctx context.Context, method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
return nil
})
opt(&options)
assert.Equal(t, 1, len(options.DialOptions))
}
func TestBuildDialOptions(t *testing.T) { func TestBuildDialOptions(t *testing.T) {
var c client var c client
agent := grpc.WithUserAgent("chrome") agent := grpc.WithUserAgent("chrome")