Compare commits

...

34 Commits

Author SHA1 Message Date
kevin
95aa65efb9 add dockerfile generator 2020-11-08 21:28:58 +08:00
kevin
3806e66cf1 simplify http server starter 2020-11-08 13:17:14 +08:00
kevin
bd430baf52 graceful shutdown refined 2020-11-08 13:08:00 +08:00
Keson
48f4154ea8 update doc (#193) 2020-11-08 13:02:48 +08:00
super_mario
2599e0d28d Close the process when shutdown is finished (#157)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2020-11-08 12:50:58 +08:00
kingxt
12327fa07d break generator when happen error (#192)
Co-authored-by: kim <xutao@xiaoheiban.cn>
2020-11-07 21:25:52 +08:00
kevin
57079bf4a4 update cli package 2020-11-07 20:01:25 +08:00
kingxt
7f6eceb5a3 add more test (#189)
* new test

* import bug when with quotation

* new test

* add test condition

* rpc template command use -o param

Co-authored-by: kim <xutao@xiaoheiban.cn>
2020-11-07 17:13:40 +08:00
kevin
7d7cb836af fix issue #186 2020-11-06 12:25:48 +08:00
kevin
f87d9d1dda refine code style 2020-11-06 12:13:28 +08:00
Keson
856b5aadb1 rpc generation fix (#184)
* reactor alert

* optimize

* add test case

* update the target directory in case proto contains option

* fix missing comments and format code
2020-11-05 19:08:34 +08:00
Keson
f7d778e0ed fix duplicate alias (#183) 2020-11-05 18:12:23 +08:00
kevin
88333ee77f faster the tests 2020-11-05 16:04:00 +08:00
Keson
e76f44a35b reactor rpc (#179)
* reactor rpc generation

* update flag

* update command

* update command

* update unit test

* delete test file

* optimize code

* update doc

* update gen pb

* rename target dir

* update mysql data type convert rule

* add done flag

* optimize req/reply parameter

* optimize req/reply parameter

* remove waste code

* remove duplicate parameter

* format code

* format code

* optimize naming

* reactor rpcv2 to rpc

* remove new line

* format code

* rename underline to snake

* reactor getParentPackage

* remove debug log

* reactor background
2020-11-05 14:12:47 +08:00
kevin
c9ec22d5f4 add https listen and serve 2020-11-05 11:56:40 +08:00
Dashuang Li
afffc1048b fix url 404 (#180) 2020-11-04 12:03:07 +08:00
kevin
d0b76b1d9a move redistest into redis package 2020-11-03 16:35:34 +08:00
kevin
b004b070d7 refine tests 2020-11-02 17:51:33 +08:00
kevin
677d581bd1 update doc 2020-11-02 17:05:09 +08:00
kingxt
b776468e69 route support no request and response (#178)
* add more test and support no request and response

* fix slash when run on windows

* optimize test
2020-11-02 13:48:16 +08:00
kevin
4c9315e984 add more tests 2020-10-31 22:10:11 +08:00
kevin
668a7011c4 add more tests 2020-10-31 20:11:12 +08:00
吴亲库里
cc07a1d69b Update sharedcalls.go (#174)
Removes unused parameters
2020-10-31 19:40:07 +08:00
kevin
7f99a3baa8 add gitee url 2020-10-31 13:58:33 +08:00
kevin
9504418462 update doc 2020-10-31 12:41:29 +08:00
kevin
b144a2335c update bookstore example for generation prototype 2020-10-31 11:42:44 +08:00
kevin
7b9ed7a313 update doc 2020-10-30 15:20:19 +08:00
kevin
3d2e9fcb84 remove wechat image 2020-10-30 11:57:32 +08:00
kevin
2b993424c1 update wechat qrcode 2020-10-30 11:54:06 +08:00
kevin
5e87b33b23 support https in rest 2020-10-29 17:44:51 +08:00
kevin
9b7cc43dcb update wechat qrcode 2020-10-29 15:32:08 +08:00
kevin
000b28cf84 update readme 2020-10-29 11:31:35 +08:00
kevin
9fd16cd278 add images back because of gitee not showing 2020-10-29 11:27:40 +08:00
kevin
b71429e16b add images back because of gitee not showing 2020-10-29 11:26:10 +08:00
151 changed files with 3397 additions and 3316 deletions

View File

@@ -2,20 +2,16 @@ package bloom
import ( import (
"testing" "testing"
"time"
"github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/lang" "github.com/tal-tech/go-zero/core/stores/redis/redistest"
"github.com/tal-tech/go-zero/core/stores/redis"
) )
func TestRedisBitSet_New_Set_Test(t *testing.T) { func TestRedisBitSet_New_Set_Test(t *testing.T) {
s, clean, err := createMiniRedis() store, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
store := redis.NewRedis(s.Addr(), redis.NodeType)
bitSet := newRedisBitSet(store, "test_key", 1024) bitSet := newRedisBitSet(store, "test_key", 1024)
isSetBefore, err := bitSet.check([]uint{0}) isSetBefore, err := bitSet.check([]uint{0})
if err != nil { if err != nil {
@@ -46,11 +42,10 @@ func TestRedisBitSet_New_Set_Test(t *testing.T) {
} }
func TestRedisBitSet_Add(t *testing.T) { func TestRedisBitSet_Add(t *testing.T) {
s, clean, err := createMiniRedis() store, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
store := redis.NewRedis(s.Addr(), redis.NodeType)
filter := New(store, "test_key", 64) filter := New(store, "test_key", 64)
assert.Nil(t, filter.Add([]byte("hello"))) assert.Nil(t, filter.Add([]byte("hello")))
assert.Nil(t, filter.Add([]byte("world"))) assert.Nil(t, filter.Add([]byte("world")))
@@ -58,22 +53,3 @@ func TestRedisBitSet_Add(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
} }
func createMiniRedis() (r *miniredis.Miniredis, clean func(), err error) {
r, err = miniredis.Run()
if err != nil {
return nil, nil, err
}
return r, func() {
ch := make(chan lang.PlaceholderType)
go func() {
r.Close()
close(ch)
}()
select {
case <-ch:
case <-time.After(time.Second):
}
}, nil
}

View File

@@ -82,6 +82,7 @@ func (pe *PeriodicalExecutor) Sync(fn func()) {
} }
func (pe *PeriodicalExecutor) Wait() { func (pe *PeriodicalExecutor) Wait() {
pe.Flush()
pe.wgBarrier.Guard(func() { pe.wgBarrier.Guard(func() {
pe.waitGroup.Wait() pe.waitGroup.Wait()
}) })

View File

@@ -6,6 +6,7 @@ import (
"github.com/alicebob/miniredis" "github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
"github.com/tal-tech/go-zero/core/stores/redis/redistest"
) )
func TestPeriodLimit_Take(t *testing.T) { func TestPeriodLimit_Take(t *testing.T) {
@@ -33,16 +34,16 @@ func TestPeriodLimit_RedisUnavailable(t *testing.T) {
} }
func testPeriodLimit(t *testing.T, opts ...LimitOption) { func testPeriodLimit(t *testing.T, opts ...LimitOption) {
s, err := miniredis.Run() store, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer s.Close() defer clean()
const ( const (
seconds = 1 seconds = 1
total = 100 total = 100
quota = 5 quota = 5
) )
l := NewPeriodLimit(seconds, quota, redis.NewRedis(s.Addr(), redis.NodeType), "periodlimit", opts...) l := NewPeriodLimit(seconds, quota, store, "periodlimit", opts...)
var allowed, hitQuota, overQuota int var allowed, hitQuota, overQuota int
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
val, err := l.Take("first") val, err := l.Take("first")

View File

@@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
"github.com/tal-tech/go-zero/core/stores/redis/redistest"
) )
func init() { func init() {
@@ -44,16 +45,16 @@ func TestTokenLimit_Rescue(t *testing.T) {
} }
func TestTokenLimit_Take(t *testing.T) { func TestTokenLimit_Take(t *testing.T) {
s, err := miniredis.Run() store, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer s.Close() defer clean()
const ( const (
total = 100 total = 100
rate = 5 rate = 5
burst = 10 burst = 10
) )
l := NewTokenLimiter(rate, burst, redis.NewRedis(s.Addr(), redis.NodeType), "tokenlimit") l := NewTokenLimiter(rate, burst, store, "tokenlimit")
var allowed int var allowed int
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
time.Sleep(time.Second / time.Duration(total)) time.Sleep(time.Second / time.Duration(total))
@@ -66,16 +67,16 @@ func TestTokenLimit_Take(t *testing.T) {
} }
func TestTokenLimit_TakeBurst(t *testing.T) { func TestTokenLimit_TakeBurst(t *testing.T) {
s, err := miniredis.Run() store, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer s.Close() defer clean()
const ( const (
total = 100 total = 100
rate = 5 rate = 5
burst = 10 burst = 10
) )
l := NewTokenLimiter(rate, burst, redis.NewRedis(s.Addr(), redis.NodeType), "tokenlimit") l := NewTokenLimiter(rate, burst, store, "tokenlimit")
var allowed int var allowed int
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
if l.Allow() { if l.Allow() {

View File

@@ -12,6 +12,7 @@ import (
"github.com/tal-tech/go-zero/core/errorx" "github.com/tal-tech/go-zero/core/errorx"
"github.com/tal-tech/go-zero/core/hash" "github.com/tal-tech/go-zero/core/hash"
"github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
"github.com/tal-tech/go-zero/core/stores/redis/redistest"
"github.com/tal-tech/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
) )
@@ -75,23 +76,23 @@ func (mc *mockedNode) TakeWithExpire(v interface{}, key string, query func(v int
func TestCache_SetDel(t *testing.T) { func TestCache_SetDel(t *testing.T) {
const total = 1000 const total = 1000
r1, clean1, err := createMiniRedis() r1, clean1, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean1() defer clean1()
r2, clean2, err := createMiniRedis() r2, clean2, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean2() defer clean2()
conf := ClusterConf{ conf := ClusterConf{
{ {
RedisConf: redis.RedisConf{ RedisConf: redis.RedisConf{
Host: r1.Addr(), Host: r1.Addr,
Type: redis.NodeType, Type: redis.NodeType,
}, },
Weight: 100, Weight: 100,
}, },
{ {
RedisConf: redis.RedisConf{ RedisConf: redis.RedisConf{
Host: r2.Addr(), Host: r2.Addr,
Type: redis.NodeType, Type: redis.NodeType,
}, },
Weight: 100, Weight: 100,
@@ -123,13 +124,13 @@ func TestCache_SetDel(t *testing.T) {
func TestCache_OneNode(t *testing.T) { func TestCache_OneNode(t *testing.T) {
const total = 1000 const total = 1000
r, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
conf := ClusterConf{ conf := ClusterConf{
{ {
RedisConf: redis.RedisConf{ RedisConf: redis.RedisConf{
Host: r.Addr(), Host: r.Addr,
Type: redis.NodeType, Type: redis.NodeType,
}, },
Weight: 100, Weight: 100,

View File

@@ -175,12 +175,12 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e
} }
if fresh { if fresh {
return nil return nil
} else {
// got the result from previous ongoing query
c.stat.IncrementTotal()
c.stat.IncrementHit()
} }
// got the result from previous ongoing query
c.stat.IncrementTotal()
c.stat.IncrementHit()
return jsonx.Unmarshal(val.([]byte), v) return jsonx.Unmarshal(val.([]byte), v)
} }

View File

@@ -15,6 +15,7 @@ import (
"github.com/tal-tech/go-zero/core/mathx" "github.com/tal-tech/go-zero/core/mathx"
"github.com/tal-tech/go-zero/core/stat" "github.com/tal-tech/go-zero/core/stat"
"github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
"github.com/tal-tech/go-zero/core/stores/redis/redistest"
"github.com/tal-tech/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
) )
@@ -26,12 +27,12 @@ func init() {
} }
func TestCacheNode_DelCache(t *testing.T) { func TestCacheNode_DelCache(t *testing.T) {
s, clean, err := createMiniRedis() store, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
cn := cacheNode{ cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType), rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())), r: rand.New(rand.NewSource(time.Now().UnixNano())),
lock: new(sync.Mutex), lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation), unstableExpiry: mathx.NewUnstable(expiryDeviation),
@@ -49,9 +50,9 @@ func TestCacheNode_DelCache(t *testing.T) {
} }
func TestCacheNode_InvalidCache(t *testing.T) { func TestCacheNode_InvalidCache(t *testing.T) {
s, clean, err := createMiniRedis() s, err := miniredis.Run()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer s.Close()
cn := cacheNode{ cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType), rds: redis.NewRedis(s.Addr(), redis.NodeType),
@@ -70,12 +71,12 @@ func TestCacheNode_InvalidCache(t *testing.T) {
} }
func TestCacheNode_Take(t *testing.T) { func TestCacheNode_Take(t *testing.T) {
s, clean, err := createMiniRedis() store, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
cn := cacheNode{ cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType), rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())), r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSharedCalls(), barrier: syncx.NewSharedCalls(),
lock: new(sync.Mutex), lock: new(sync.Mutex),
@@ -91,18 +92,18 @@ func TestCacheNode_Take(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "value", str) assert.Equal(t, "value", str)
assert.Nil(t, cn.GetCache("any", &str)) assert.Nil(t, cn.GetCache("any", &str))
val, err := s.Get("any") val, err := store.Get("any")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, `"value"`, val) assert.Equal(t, `"value"`, val)
} }
func TestCacheNode_TakeNotFound(t *testing.T) { func TestCacheNode_TakeNotFound(t *testing.T) {
s, clean, err := createMiniRedis() store, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
cn := cacheNode{ cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType), rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())), r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSharedCalls(), barrier: syncx.NewSharedCalls(),
lock: new(sync.Mutex), lock: new(sync.Mutex),
@@ -116,18 +117,18 @@ func TestCacheNode_TakeNotFound(t *testing.T) {
}) })
assert.Equal(t, errTestNotFound, err) assert.Equal(t, errTestNotFound, err)
assert.Equal(t, errTestNotFound, cn.GetCache("any", &str)) assert.Equal(t, errTestNotFound, cn.GetCache("any", &str))
val, err := s.Get("any") val, err := store.Get("any")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, `*`, val) assert.Equal(t, `*`, val)
s.Set("any", "*") store.Set("any", "*")
err = cn.Take(&str, "any", func(v interface{}) error { err = cn.Take(&str, "any", func(v interface{}) error {
return nil return nil
}) })
assert.Equal(t, errTestNotFound, err) assert.Equal(t, errTestNotFound, err)
assert.Equal(t, errTestNotFound, cn.GetCache("any", &str)) assert.Equal(t, errTestNotFound, cn.GetCache("any", &str))
s.Del("any") store.Del("any")
var errDummy = errors.New("dummy") var errDummy = errors.New("dummy")
err = cn.Take(&str, "any", func(v interface{}) error { err = cn.Take(&str, "any", func(v interface{}) error {
return errDummy return errDummy
@@ -136,12 +137,12 @@ func TestCacheNode_TakeNotFound(t *testing.T) {
} }
func TestCacheNode_TakeWithExpire(t *testing.T) { func TestCacheNode_TakeWithExpire(t *testing.T) {
s, clean, err := createMiniRedis() store, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
cn := cacheNode{ cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType), rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())), r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSharedCalls(), barrier: syncx.NewSharedCalls(),
lock: new(sync.Mutex), lock: new(sync.Mutex),
@@ -157,18 +158,18 @@ func TestCacheNode_TakeWithExpire(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "value", str) assert.Equal(t, "value", str)
assert.Nil(t, cn.GetCache("any", &str)) assert.Nil(t, cn.GetCache("any", &str))
val, err := s.Get("any") val, err := store.Get("any")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, `"value"`, val) assert.Equal(t, `"value"`, val)
} }
func TestCacheNode_String(t *testing.T) { func TestCacheNode_String(t *testing.T) {
s, clean, err := createMiniRedis() store, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
cn := cacheNode{ cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType), rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())), r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSharedCalls(), barrier: syncx.NewSharedCalls(),
lock: new(sync.Mutex), lock: new(sync.Mutex),
@@ -176,16 +177,16 @@ func TestCacheNode_String(t *testing.T) {
stat: NewCacheStat("any"), stat: NewCacheStat("any"),
errNotFound: errors.New("any"), errNotFound: errors.New("any"),
} }
assert.Equal(t, s.Addr(), cn.String()) assert.Equal(t, store.Addr, cn.String())
} }
func TestCacheValueWithBigInt(t *testing.T) { func TestCacheValueWithBigInt(t *testing.T) {
s, clean, err := createMiniRedis() store, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
cn := cacheNode{ cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType), rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())), r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSharedCalls(), barrier: syncx.NewSharedCalls(),
lock: new(sync.Mutex), lock: new(sync.Mutex),

View File

@@ -2,11 +2,8 @@ package cache
import ( import (
"testing" "testing"
"time"
"github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/lang"
) )
func TestFormatKeys(t *testing.T) { func TestFormatKeys(t *testing.T) {
@@ -27,22 +24,3 @@ func TestTotalWeights(t *testing.T) {
}) })
assert.Equal(t, 1, val) assert.Equal(t, 1, val)
} }
func createMiniRedis() (r *miniredis.Miniredis, clean func(), err error) {
r, err = miniredis.Run()
if err != nil {
return nil, nil, err
}
return r, func() {
ch := make(chan lang.PlaceholderType)
go func() {
r.Close()
close(ch)
}()
select {
case <-ch:
case <-time.After(time.Second):
}
}, nil
}

View File

@@ -11,7 +11,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/alicebob/miniredis"
"github.com/globalsign/mgo" "github.com/globalsign/mgo"
"github.com/globalsign/mgo/bson" "github.com/globalsign/mgo/bson"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -19,6 +18,7 @@ import (
"github.com/tal-tech/go-zero/core/stores/cache" "github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/mongo" "github.com/tal-tech/go-zero/core/stores/mongo"
"github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
"github.com/tal-tech/go-zero/core/stores/redis/redistest"
) )
func init() { func init() {
@@ -27,12 +27,10 @@ func init() {
func TestStat(t *testing.T) { func TestStat(t *testing.T) {
resetStats() resetStats()
s, err := miniredis.Run() r, clean, err := redistest.CreateRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
cach := cache.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound) cach := cache.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach) c := newCollection(dummyConn{}, cach)
@@ -73,12 +71,10 @@ func TestStatCacheFails(t *testing.T) {
func TestStatDbFails(t *testing.T) { func TestStatDbFails(t *testing.T) {
resetStats() resetStats()
s, err := miniredis.Run() r, clean, err := redistest.CreateRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
cach := cache.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound) cach := cache.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach) c := newCollection(dummyConn{}, cach)
@@ -97,12 +93,10 @@ func TestStatDbFails(t *testing.T) {
func TestStatFromMemory(t *testing.T) { func TestStatFromMemory(t *testing.T) {
resetStats() resetStats()
s, err := miniredis.Run() r, clean, err := redistest.CreateRedis()
if err != nil { assert.Nil(t, err)
t.Error(err) defer clean()
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
cach := cache.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound) cach := cache.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach) c := newCollection(dummyConn{}, cach)

View File

@@ -0,0 +1,28 @@
package redistest
import (
"time"
"github.com/alicebob/miniredis"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/core/stores/redis"
)
func CreateRedis() (r *redis.Redis, clean func(), err error) {
mr, err := miniredis.Run()
if err != nil {
return nil, nil, err
}
return redis.NewRedis(mr.Addr(), redis.NodeType), func() {
ch := make(chan lang.PlaceholderType)
go func() {
mr.Close()
close(ch)
}()
select {
case <-ch:
case <-time.After(time.Second):
}
}, nil
}

View File

@@ -16,11 +16,12 @@ import (
"github.com/alicebob/miniredis" "github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/lang" "github.com/tal-tech/go-zero/core/fx"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stat" "github.com/tal-tech/go-zero/core/stat"
"github.com/tal-tech/go-zero/core/stores/cache" "github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
"github.com/tal-tech/go-zero/core/stores/redis/redistest"
"github.com/tal-tech/go-zero/core/stores/sqlx" "github.com/tal-tech/go-zero/core/stores/sqlx"
) )
@@ -31,16 +32,15 @@ func init() {
func TestCachedConn_GetCache(t *testing.T) { func TestCachedConn_GetCache(t *testing.T) {
resetStats() resetStats()
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
var value string var value string
err = c.GetCache("any", &value) err = c.GetCache("any", &value)
assert.Equal(t, ErrNotFound, err) assert.Equal(t, ErrNotFound, err)
s.Set("any", `"value"`) r.Set("any", `"value"`)
err = c.GetCache("any", &value) err = c.GetCache("any", &value)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "value", value) assert.Equal(t, "value", value)
@@ -48,11 +48,10 @@ func TestCachedConn_GetCache(t *testing.T) {
func TestStat(t *testing.T) { func TestStat(t *testing.T) {
resetStats() resetStats()
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@@ -72,15 +71,14 @@ func TestStat(t *testing.T) {
func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) { func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) {
resetStats() resetStats()
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewConn(dummySqlConn{}, cache.CacheConf{ c := NewConn(dummySqlConn{}, cache.CacheConf{
{ {
RedisConf: redis.RedisConf{ RedisConf: redis.RedisConf{
Host: s.Addr(), Host: r.Addr,
Type: redis.NodeType, Type: redis.NodeType,
}, },
Weight: 100, Weight: 100,
@@ -122,11 +120,10 @@ func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) {
func TestCachedConn_QueryRowIndex_HasCache(t *testing.T) { func TestCachedConn_QueryRowIndex_HasCache(t *testing.T) {
resetStats() resetStats()
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10), c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
cache.WithNotFoundExpiry(time.Second)) cache.WithNotFoundExpiry(time.Second))
@@ -210,16 +207,13 @@ func TestCachedConn_QueryRowIndex_HasCache_IntPrimary(t *testing.T) {
}, },
} }
s, clean, err := createMiniRedis()
assert.Nil(t, err)
defer clean()
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
resetStats() resetStats()
s.FlushAll() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10), c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
cache.WithNotFoundExpiry(time.Second)) cache.WithNotFoundExpiry(time.Second))
@@ -256,11 +250,10 @@ func TestCachedConn_QueryRowIndex_HasWrongCache(t *testing.T) {
for k, v := range caches { for k, v := range caches {
t.Run(k+"/"+v, func(t *testing.T) { t.Run(k+"/"+v, func(t *testing.T) {
resetStats() resetStats()
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10), c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
cache.WithNotFoundExpiry(time.Second)) cache.WithNotFoundExpiry(time.Second))
@@ -312,11 +305,10 @@ func TestStatCacheFails(t *testing.T) {
func TestStatDbFails(t *testing.T) { func TestStatDbFails(t *testing.T) {
resetStats() resetStats()
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
@@ -334,11 +326,10 @@ func TestStatDbFails(t *testing.T) {
func TestStatFromMemory(t *testing.T) { func TestStatFromMemory(t *testing.T) {
resetStats() resetStats()
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
var all sync.WaitGroup var all sync.WaitGroup
@@ -393,7 +384,7 @@ func TestStatFromMemory(t *testing.T) {
} }
func TestCachedConnQueryRow(t *testing.T) { func TestCachedConnQueryRow(t *testing.T) {
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
@@ -404,7 +395,6 @@ func TestCachedConnQueryRow(t *testing.T) {
var conn trackedConn var conn trackedConn
var user string var user string
var ran bool var ran bool
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30)) c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error { err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error {
ran = true ran = true
@@ -412,7 +402,7 @@ func TestCachedConnQueryRow(t *testing.T) {
return nil return nil
}) })
assert.Nil(t, err) assert.Nil(t, err)
actualValue, err := s.Get(key) actualValue, err := r.Get(key)
assert.Nil(t, err) assert.Nil(t, err)
var actual string var actual string
assert.Nil(t, json.Unmarshal([]byte(actualValue), &actual)) assert.Nil(t, json.Unmarshal([]byte(actualValue), &actual))
@@ -422,7 +412,7 @@ func TestCachedConnQueryRow(t *testing.T) {
} }
func TestCachedConnQueryRowFromCache(t *testing.T) { func TestCachedConnQueryRowFromCache(t *testing.T) {
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
@@ -433,7 +423,6 @@ func TestCachedConnQueryRowFromCache(t *testing.T) {
var conn trackedConn var conn trackedConn
var user string var user string
var ran bool var ran bool
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30)) c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
assert.Nil(t, c.SetCache(key, value)) assert.Nil(t, c.SetCache(key, value))
err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error { err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error {
@@ -442,7 +431,7 @@ func TestCachedConnQueryRowFromCache(t *testing.T) {
return nil return nil
}) })
assert.Nil(t, err) assert.Nil(t, err)
actualValue, err := s.Get(key) actualValue, err := r.Get(key)
assert.Nil(t, err) assert.Nil(t, err)
var actual string var actual string
assert.Nil(t, json.Unmarshal([]byte(actualValue), &actual)) assert.Nil(t, json.Unmarshal([]byte(actualValue), &actual))
@@ -452,7 +441,7 @@ func TestCachedConnQueryRowFromCache(t *testing.T) {
} }
func TestQueryRowNotFound(t *testing.T) { func TestQueryRowNotFound(t *testing.T) {
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
@@ -460,7 +449,6 @@ func TestQueryRowNotFound(t *testing.T) {
var conn trackedConn var conn trackedConn
var user string var user string
var ran int var ran int
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30)) c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error { err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error {
@@ -473,12 +461,11 @@ func TestQueryRowNotFound(t *testing.T) {
} }
func TestCachedConnExec(t *testing.T) { func TestCachedConnExec(t *testing.T) {
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
var conn trackedConn var conn trackedConn
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10)) c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
_, err = c.ExecNoCache("delete from user_table where id='kevin'") _, err = c.ExecNoCache("delete from user_table where id='kevin'")
assert.Nil(t, err) assert.Nil(t, err)
@@ -486,24 +473,26 @@ func TestCachedConnExec(t *testing.T) {
} }
func TestCachedConnExecDropCache(t *testing.T) { func TestCachedConnExecDropCache(t *testing.T) {
s, clean, err := createMiniRedis() r, err := miniredis.Run()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer fx.DoWithTimeout(func() error {
r.Close()
return nil
}, time.Second)
const ( const (
key = "user" key = "user"
value = "any" value = "any"
) )
var conn trackedConn var conn trackedConn
r := redis.NewRedis(s.Addr(), redis.NodeType) c := NewNodeConn(&conn, redis.NewRedis(r.Addr(), redis.NodeType), cache.WithExpiry(time.Second*30))
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
assert.Nil(t, c.SetCache(key, value)) assert.Nil(t, c.SetCache(key, value))
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) { _, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
return conn.Exec("delete from user_table where id='kevin'") return conn.Exec("delete from user_table where id='kevin'")
}, key) }, key)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, conn.execValue) assert.True(t, conn.execValue)
_, err = s.Get(key) _, err = r.Get(key)
assert.Exactly(t, miniredis.ErrKeyNotFound, err) assert.Exactly(t, miniredis.ErrKeyNotFound, err)
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) { _, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
return nil, errors.New("foo") return nil, errors.New("foo")
@@ -524,12 +513,11 @@ func TestCachedConnExecDropCacheFailed(t *testing.T) {
} }
func TestCachedConnQueryRows(t *testing.T) { func TestCachedConnQueryRows(t *testing.T) {
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
var conn trackedConn var conn trackedConn
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10)) c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
var users []string var users []string
err = c.QueryRowsNoCache(&users, "select user from user_table where id='kevin'") err = c.QueryRowsNoCache(&users, "select user from user_table where id='kevin'")
@@ -538,12 +526,11 @@ func TestCachedConnQueryRows(t *testing.T) {
} }
func TestCachedConnTransact(t *testing.T) { func TestCachedConnTransact(t *testing.T) {
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
var conn trackedConn var conn trackedConn
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10)) c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
err = c.Transact(func(session sqlx.Session) error { err = c.Transact(func(session sqlx.Session) error {
return nil return nil
@@ -553,7 +540,7 @@ func TestCachedConnTransact(t *testing.T) {
} }
func TestQueryRowNoCache(t *testing.T) { func TestQueryRowNoCache(t *testing.T) {
s, clean, err := createMiniRedis() r, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
@@ -563,7 +550,6 @@ func TestQueryRowNoCache(t *testing.T) {
) )
var user string var user string
var ran bool var ran bool
r := redis.NewRedis(s.Addr(), redis.NodeType)
conn := dummySqlConn{queryRow: func(v interface{}, q string, args ...interface{}) error { conn := dummySqlConn{queryRow: func(v interface{}, q string, args ...interface{}) error {
user = value user = value
ran = true ran = true
@@ -639,22 +625,3 @@ func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error {
c.transactValue = true c.transactValue = true
return c.dummySqlConn.Transact(fn) return c.dummySqlConn.Transact(fn)
} }
func createMiniRedis() (r *miniredis.Miniredis, clean func(), err error) {
r, err = miniredis.Run()
if err != nil {
return nil, nil, err
}
return r, func() {
ch := make(chan lang.PlaceholderType)
go func() {
r.Close()
close(ch)
}()
select {
case <-ch:
case <-time.After(time.Second):
}
}, nil
}

View File

@@ -33,7 +33,7 @@ func NewSharedCalls() SharedCalls {
} }
func (g *sharedGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) { func (g *sharedGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
c, done := g.createCall(key, fn) c, done := g.createCall(key)
if done { if done {
return c.val, c.err return c.val, c.err
} }
@@ -43,7 +43,7 @@ func (g *sharedGroup) Do(key string, fn func() (interface{}, error)) (interface{
} }
func (g *sharedGroup) DoEx(key string, fn func() (interface{}, error)) (val interface{}, fresh bool, err error) { func (g *sharedGroup) DoEx(key string, fn func() (interface{}, error)) (val interface{}, fresh bool, err error) {
c, done := g.createCall(key, fn) c, done := g.createCall(key)
if done { if done {
return c.val, false, c.err return c.val, false, c.err
} }
@@ -52,7 +52,7 @@ func (g *sharedGroup) DoEx(key string, fn func() (interface{}, error)) (val inte
return c.val, true, c.err return c.val, true, c.err
} }
func (g *sharedGroup) createCall(key string, fn func() (interface{}, error)) (c *call, done bool) { func (g *sharedGroup) createCall(key string) (c *call, done bool) {
g.lock.Lock() g.lock.Lock()
if c, ok := g.calls[key]; ok { if c, ok := g.calls[key]; ok {
g.lock.Unlock() g.lock.Unlock()

Binary file not shown.

After

Width:  |  Height:  |  Size: 286 KiB

BIN
doc/images/architecture.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 333 KiB

BIN
doc/images/benchmark.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
doc/images/go-zero.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 170 KiB

BIN
doc/images/resilience.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 150 KiB

View File

@@ -8,9 +8,9 @@ import (
"fmt" "fmt"
"bookstore/rpc/add/internal/config" "bookstore/rpc/add/internal/config"
add "bookstore/rpc/add/internal/pb"
"bookstore/rpc/add/internal/server" "bookstore/rpc/add/internal/server"
"bookstore/rpc/add/internal/svc" "bookstore/rpc/add/internal/svc"
add "bookstore/rpc/add/pb"
"github.com/tal-tech/go-zero/core/conf" "github.com/tal-tech/go-zero/core/conf"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"

View File

@@ -8,13 +8,15 @@ package adder
import ( import (
"context" "context"
add "bookstore/rpc/add/pb" add "bookstore/rpc/add/internal/pb"
"github.com/tal-tech/go-zero/core/jsonx"
"github.com/tal-tech/go-zero/zrpc" "github.com/tal-tech/go-zero/zrpc"
) )
type ( type (
AddReq = add.AddReq
AddResp = add.AddResp
Adder interface { Adder interface {
Add(ctx context.Context, in *AddReq) (*AddResp, error) Add(ctx context.Context, in *AddReq) (*AddResp, error)
} }
@@ -31,33 +33,6 @@ func NewAdder(cli zrpc.Client) Adder {
} }
func (m *defaultAdder) Add(ctx context.Context, in *AddReq) (*AddResp, error) { func (m *defaultAdder) Add(ctx context.Context, in *AddReq) (*AddResp, error) {
var request add.AddReq adder := add.NewAdderClient(m.cli.Conn())
bts, err := jsonx.Marshal(in) return adder.Add(ctx, in)
if err != nil {
return nil, errJsonConvert
}
err = jsonx.Unmarshal(bts, &request)
if err != nil {
return nil, errJsonConvert
}
client := add.NewAdderClient(m.cli.Conn())
resp, err := client.Add(ctx, &request)
if err != nil {
return nil, err
}
var ret AddResp
bts, err = jsonx.Marshal(resp)
if err != nil {
return nil, errJsonConvert
}
err = jsonx.Unmarshal(bts, &ret)
if err != nil {
return nil, errJsonConvert
}
return &ret, nil
} }

View File

@@ -1,19 +0,0 @@
// Code generated by goctl. DO NOT EDIT!
// Source: add.proto
package adder
import "errors"
var errJsonConvert = errors.New("json convert error")
type (
AddReq struct {
Book string `json:"book,omitempty"`
Price int64 `json:"price,omitempty"`
}
AddResp struct {
Ok bool `json:"ok,omitempty"`
}
)

View File

@@ -3,8 +3,8 @@ package logic
import ( import (
"context" "context"
add "bookstore/rpc/add/internal/pb"
"bookstore/rpc/add/internal/svc" "bookstore/rpc/add/internal/svc"
add "bookstore/rpc/add/pb"
"bookstore/rpc/model" "bookstore/rpc/model"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"

View File

@@ -14,13 +14,13 @@ It has these top-level messages:
*/ */
package add package add
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import ( import (
context "golang.org/x/net/context" "fmt"
grpc "google.golang.org/grpc" "math"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/grpc"
) )
// Reference imports to suppress errors if they are not otherwise used. // Reference imports to suppress errors if they are not otherwise used.

View File

@@ -7,8 +7,8 @@ import (
"context" "context"
"bookstore/rpc/add/internal/logic" "bookstore/rpc/add/internal/logic"
add "bookstore/rpc/add/internal/pb"
"bookstore/rpc/add/internal/svc" "bookstore/rpc/add/internal/svc"
add "bookstore/rpc/add/pb"
) )
type AdderServer struct { type AdderServer struct {

View File

@@ -8,9 +8,9 @@ import (
"fmt" "fmt"
"bookstore/rpc/check/internal/config" "bookstore/rpc/check/internal/config"
check "bookstore/rpc/check/internal/pb"
"bookstore/rpc/check/internal/server" "bookstore/rpc/check/internal/server"
"bookstore/rpc/check/internal/svc" "bookstore/rpc/check/internal/svc"
check "bookstore/rpc/check/pb"
"github.com/tal-tech/go-zero/core/conf" "github.com/tal-tech/go-zero/core/conf"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"

View File

@@ -8,13 +8,15 @@ package checker
import ( import (
"context" "context"
check "bookstore/rpc/check/pb" check "bookstore/rpc/check/internal/pb"
"github.com/tal-tech/go-zero/core/jsonx"
"github.com/tal-tech/go-zero/zrpc" "github.com/tal-tech/go-zero/zrpc"
) )
type ( type (
CheckReq = check.CheckReq
CheckResp = check.CheckResp
Checker interface { Checker interface {
Check(ctx context.Context, in *CheckReq) (*CheckResp, error) Check(ctx context.Context, in *CheckReq) (*CheckResp, error)
} }
@@ -31,33 +33,6 @@ func NewChecker(cli zrpc.Client) Checker {
} }
func (m *defaultChecker) Check(ctx context.Context, in *CheckReq) (*CheckResp, error) { func (m *defaultChecker) Check(ctx context.Context, in *CheckReq) (*CheckResp, error) {
var request check.CheckReq checker := check.NewCheckerClient(m.cli.Conn())
bts, err := jsonx.Marshal(in) return checker.Check(ctx, in)
if err != nil {
return nil, errJsonConvert
}
err = jsonx.Unmarshal(bts, &request)
if err != nil {
return nil, errJsonConvert
}
client := check.NewCheckerClient(m.cli.Conn())
resp, err := client.Check(ctx, &request)
if err != nil {
return nil, err
}
var ret CheckResp
bts, err = jsonx.Marshal(resp)
if err != nil {
return nil, errJsonConvert
}
err = jsonx.Unmarshal(bts, &ret)
if err != nil {
return nil, errJsonConvert
}
return &ret, nil
} }

View File

@@ -1,19 +0,0 @@
// Code generated by goctl. DO NOT EDIT!
// Source: check.proto
package checker
import "errors"
var errJsonConvert = errors.New("json convert error")
type (
CheckReq struct {
Book string `json:"book,omitempty"`
}
CheckResp struct {
Found bool `json:"found,omitempty"`
Price int64 `json:"price,omitempty"`
}
)

View File

@@ -3,8 +3,8 @@ package logic
import ( import (
"context" "context"
check "bookstore/rpc/check/internal/pb"
"bookstore/rpc/check/internal/svc" "bookstore/rpc/check/internal/svc"
check "bookstore/rpc/check/pb"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
) )

View File

@@ -14,13 +14,13 @@ It has these top-level messages:
*/ */
package check package check
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import ( import (
context "golang.org/x/net/context" "fmt"
grpc "google.golang.org/grpc" "math"
"github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/grpc"
) )
// Reference imports to suppress errors if they are not otherwise used. // Reference imports to suppress errors if they are not otherwise used.

View File

@@ -7,8 +7,8 @@ import (
"context" "context"
"bookstore/rpc/check/internal/logic" "bookstore/rpc/check/internal/logic"
check "bookstore/rpc/check/internal/pb"
"bookstore/rpc/check/internal/svc" "bookstore/rpc/check/internal/svc"
check "bookstore/rpc/check/pb"
) )
type CheckerServer struct { type CheckerServer struct {

2
go.mod
View File

@@ -43,7 +43,7 @@ require (
github.com/spaolacci/murmur3 v1.1.0 github.com/spaolacci/murmur3 v1.1.0
github.com/stretchr/testify v1.5.1 github.com/stretchr/testify v1.5.1
github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 // indirect github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 // indirect
github.com/urfave/cli v1.22.4 github.com/urfave/cli v1.22.5
github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb // indirect github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb // indirect
go.etcd.io/etcd v0.0.0-20200402134248-51bdeb39e698 go.etcd.io/etcd v0.0.0-20200402134248-51bdeb39e698

2
go.sum
View File

@@ -277,6 +277,8 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6/go.mod h1
github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
github.com/urfave/cli v1.22.4 h1:u7tSpNPPswAFymm8IehJhy4uJMlUuU/GmqSkvJ1InXA= github.com/urfave/cli v1.22.4 h1:u7tSpNPPswAFymm8IehJhy4uJMlUuU/GmqSkvJ1InXA=
github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.4/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/urfave/cli v1.22.5 h1:lNq9sAHXK2qfdI8W+GRItjCEkI+2oR4d+MEHy1CKXoU=
github.com/urfave/cli v1.22.5/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5Qo6v2eYzo7kUS51QINcR5jNpbZS8= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5Qo6v2eYzo7kUS51QINcR5jNpbZS8=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6 h1:YdYsPAZ2pC6Tow/nPZOPQ96O3hm/ToAkGsPLzedXERk= github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6 h1:YdYsPAZ2pC6Tow/nPZOPQ96O3hm/ToAkGsPLzedXERk=

View File

@@ -1,4 +1,4 @@
<img align="right" width="150px" src="https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/go-zero.png"> <img align="right" width="150px" src="doc/images/go-zero.png">
# go-zero # go-zero
@@ -25,7 +25,7 @@ Advantages of go-zero:
* auto validate the request parameters from clients * auto validate the request parameters from clients
* plenty of builtin microservice management and concurrent toolkits * plenty of builtin microservice management and concurrent toolkits
<img src="https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/architecture-en.png" alt="Architecture" width="1500" /> <img src="doc/images/architecture-en.png" alt="Architecture" width="1500" />
## 1. Backgrounds of go-zero ## 1. Backgrounds of go-zero
@@ -76,7 +76,7 @@ go-zero is a web and rpc framework that integrates lots of engineering practices
As below, go-zero protects the system with couple layers and mechanisms: As below, go-zero protects the system with couple layers and mechanisms:
![Resilience](https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/resilience-en.png) ![Resilience](doc/images/resilience-en.png)
## 4. Future development plans of go-zero ## 4. Future development plans of go-zero
@@ -121,19 +121,17 @@ go get -u github.com/tal-tech/go-zero
} }
service greet-api { service greet-api {
@server( @handler GreetHandler
handler: GreetHandler
)
get /greet/from/:name(Request) returns (Response); get /greet/from/:name(Request) returns (Response);
} }
``` ```
the .api files also can be generate by goctl, like below: the .api files also can be generate by goctl, like below:
```shell ```shell
goctl api -o greet.api goctl api -o greet.api
``` ```
3. generate the go server side code 3. generate the go server side code
```shell ```shell
@@ -202,7 +200,7 @@ go get -u github.com/tal-tech/go-zero
## 7. Benchmark ## 7. Benchmark
![benchmark](https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/benchmark.png) ![benchmark](doc/images/benchmark.png)
[Checkout the test code](https://github.com/smallnest/go-web-framework-benchmark) [Checkout the test code](https://github.com/smallnest/go-web-framework-benchmark)

View File

@@ -1,4 +1,4 @@
<img align="right" width="150px" src="https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/go-zero.png"> <img align="right" width="150px" src="doc/images/go-zero.png">
# go-zero # go-zero
@@ -25,7 +25,7 @@ go-zero 包含极简的 API 定义和生成工具 goctl可以根据定义的
* 自动校验客户端请求参数合法性 * 自动校验客户端请求参数合法性
* 大量微服务治理和并发工具包 * 大量微服务治理和并发工具包
<img src="https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/architecture.png" alt="架构图" width="1500" /> <img src="doc/images/architecture.png" alt="架构图" width="1500" />
## 1. go-zero 框架背景 ## 1. go-zero 框架背景
@@ -77,7 +77,9 @@ go-zero 是一个集成了各种工程实践的包含 web 和 rpc 框架,有
如下图,我们从多个层面保障了整体服务的高可用: 如下图,我们从多个层面保障了整体服务的高可用:
![弹性设计](https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/resilience.jpg) ![弹性设计](doc/images/resilience.jpg)
觉得不错的话,别忘 **star** 👏
## 4. Installation ## 4. Installation
@@ -93,7 +95,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
[快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md) [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
[快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/doc/bookstore.md) [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/frame/bookstore.md)
1. 安装 goctl 工具 1. 安装 goctl 工具
@@ -148,7 +150,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
## 6. Benchmark ## 6. Benchmark
![benchmark](https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/benchmark.png) ![benchmark](doc/images/benchmark.png)
[测试代码见这里](https://github.com/smallnest/go-web-framework-benchmark) [测试代码见这里](https://github.com/smallnest/go-web-framework-benchmark)
@@ -160,7 +162,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
* awesome 系列 * awesome 系列
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md) * [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/doc/bookstore.md) * [快速构建高并发微服务 - 多 RPC 版](https://github.com/tal-tech/zero-doc/blob/main/docs/frame/bookstore.md)
* [goctl 使用帮助](https://github.com/tal-tech/zero-doc/blob/main/doc/goctl.md) * [goctl 使用帮助](https://github.com/tal-tech/zero-doc/blob/main/doc/goctl.md)
* [通过 MapReduce 降低服务响应时间](https://github.com/tal-tech/zero-doc/blob/main/doc/mapreduce.md) * [通过 MapReduce 降低服务响应时间](https://github.com/tal-tech/zero-doc/blob/main/doc/mapreduce.md)
* [关键字替换和敏感词过滤工具](https://github.com/tal-tech/zero-doc/blob/main/doc/keywords.md) * [关键字替换和敏感词过滤工具](https://github.com/tal-tech/zero-doc/blob/main/doc/keywords.md)
@@ -170,18 +172,22 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
* [文本序列化和反序列化](https://github.com/tal-tech/zero-doc/blob/main/doc/mapping.md) * [文本序列化和反序列化](https://github.com/tal-tech/zero-doc/blob/main/doc/mapping.md)
* [快速构建 jwt 鉴权认证](https://github.com/tal-tech/zero-doc/blob/main/doc/jwt.md) * [快速构建 jwt 鉴权认证](https://github.com/tal-tech/zero-doc/blob/main/doc/jwt.md)
## 9. 微信交流群 ## 8. 微信交流群
加群之前有劳给一个 star一个小小的 star 是作者们回答海量问题的动力。
如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。 如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。
您可以在群内提出使用中需要改进的地方,我们会考虑合理性并尽快修改。 您可以在群内提出使用中需要改进的地方,我们会考虑合理性并尽快修改。
如果您发现 bug 请及时提 issue我们会尽快确认并修改。 如果您发现 ***bug*** 请及时提 ***issue***,我们会尽快确认并修改。
<!-- 扫码后请加群主,便于我邀请您进讨论群,并请退出扫码网关群,谢!--> 为了防止广告用户、识别技术同行,请 ***star*** 后加我时注明 **github** 当前 ***star*** 数,我再拉进 **go-zero** 群,谢!
开源中国年度评选给go-zero投上一票[https://www.oschina.net/p/go-zero](https://www.oschina.net/p/go-zero) 加我之前有劳点一下 ***star***,一个小小的 ***star*** 是作者们回答海量问题的动力🤝
<img src="https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/wechat.jpg" alt="wechat" width="300" /> <img src="https://gitee.com/kevwan/static/raw/master/images/wechat.jpg" alt="wechat" width="300" />
项目地址:[https://github.com/tal-tech/go-zero](https://github.com/tal-tech/go-zero)
码云地址:[https://gitee.com/kevwan/go-zero](https://gitee.com/kevwan/go-zero) (国内用户可访问gitee每日自动从github同步代码)
开源中国年度评选,给 **go-zero** 投上一票:[https://www.oschina.net/p/go-zero](https://www.oschina.net/p/go-zero)

View File

@@ -18,7 +18,7 @@ type (
PrivateKeys []PrivateKeyConf PrivateKeys []PrivateKeyConf
} }
// why not name it as Conf, because we need to consider usage like: // Why not name it as Conf, because we need to consider usage like:
// type Config struct { // type Config struct {
// zrpc.RpcConf // zrpc.RpcConf
// rest.RestConf // rest.RestConf
@@ -28,9 +28,11 @@ type (
service.ServiceConf service.ServiceConf
Host string `json:",default=0.0.0.0"` Host string `json:",default=0.0.0.0"`
Port int Port int
Verbose bool `json:",optional"` CertFile string `json:",optional"`
MaxConns int `json:",default=10000"` KeyFile string `json:",optional"`
MaxBytes int64 `json:",default=1048576,range=[0:8388608]"` Verbose bool `json:",optional"`
MaxConns int `json:",default=10000"`
MaxBytes int64 `json:",default=1048576,range=[0:8388608]"`
// milliseconds // milliseconds
Timeout int64 `json:",default=3000"` Timeout int64 `json:",default=3000"`
CpuThreshold int64 `json:",default=900,range=[0:1000]"` CpuThreshold int64 `json:",default=900,range=[0:1000]"`

View File

@@ -65,7 +65,11 @@ func (s *engine) StartWithRouter(router httpx.Router) error {
return err return err
} }
return internal.StartHttp(s.conf.Host, s.conf.Port, router) if len(s.conf.CertFile) == 0 && len(s.conf.KeyFile) == 0 {
return internal.StartHttp(s.conf.Host, s.conf.Port, router)
}
return internal.StartHttps(s.conf.Host, s.conf.Port, s.conf.CertFile, s.conf.KeyFile, router)
} }
func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain, func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,

View File

@@ -19,9 +19,7 @@ func TestParseForm(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", nil)
assert.Nil(t, err) assert.Nil(t, err)
assert.Nil(t, Parse(r, &v))
err = Parse(r, &v)
assert.Nil(t, err)
assert.Equal(t, "hello", v.Name) assert.Equal(t, "hello", v.Name)
assert.Equal(t, 18, v.Age) assert.Equal(t, 18, v.Age)
assert.Equal(t, 3.4, v.Percent) assert.Equal(t, 3.4, v.Percent)
@@ -97,8 +95,44 @@ Content-Disposition: form-data; name="age"
r := httptest.NewRequest(http.MethodPost, "http://localhost:3333/", strings.NewReader(body)) r := httptest.NewRequest(http.MethodPost, "http://localhost:3333/", strings.NewReader(body))
r.Header.Set(ContentType, "multipart/form-data; boundary=--------------------------220477612388154780019383") r.Header.Set(ContentType, "multipart/form-data; boundary=--------------------------220477612388154780019383")
err := Parse(r, &v) assert.Nil(t, Parse(r, &v))
assert.Nil(t, err) assert.Equal(t, "kevin", v.Name)
assert.Equal(t, 18, v.Age)
}
func TestParseMultipartFormWrongBoundary(t *testing.T) {
var v struct {
Name string `form:"name"`
Age int `form:"age"`
}
body := strings.Replace(`----------------------------22047761238815478001938
Content-Disposition: form-data; name="name"
kevin
----------------------------22047761238815478001938
Content-Disposition: form-data; name="age"
18
----------------------------22047761238815478001938--`, "\n", "\r\n", -1)
r := httptest.NewRequest(http.MethodPost, "http://localhost:3333/", strings.NewReader(body))
r.Header.Set(ContentType, "multipart/form-data; boundary=--------------------------220477612388154780019383")
assert.NotNil(t, Parse(r, &v))
}
func TestParseJsonBody(t *testing.T) {
var v struct {
Name string `json:"name"`
Age int `json:"age"`
}
body := `{"name":"kevin", "age": 18}`
r := httptest.NewRequest(http.MethodPost, "http://localhost:3333/", strings.NewReader(body))
r.Header.Set(ContentType, ApplicationJson)
assert.Nil(t, Parse(r, &v))
assert.Equal(t, "kevin", v.Name) assert.Equal(t, "kevin", v.Name)
assert.Equal(t, 18, v.Age) assert.Equal(t, 18, v.Age)
} }
@@ -111,9 +145,7 @@ func TestParseRequired(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello", nil)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, Parse(r, &v))
err = Parse(r, &v)
assert.NotNil(t, err)
} }
func TestParseOptions(t *testing.T) { func TestParseOptions(t *testing.T) {
@@ -123,9 +155,7 @@ func TestParseOptions(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?pos=4", nil) r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?pos=4", nil)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, Parse(r, &v))
err = Parse(r, &v)
assert.NotNil(t, err)
} }
func BenchmarkParseRaw(b *testing.B) { func BenchmarkParseRaw(b *testing.B) {

View File

@@ -23,5 +23,5 @@ func WithPathVars(r *http.Request, params map[string]string) *http.Request {
type contextKey string type contextKey string
func (c contextKey) String() string { func (c contextKey) String() string {
return "rest/internal/context context key" + string(c) return "rest/internal/context key: " + string(c)
} }

View File

@@ -0,0 +1,32 @@
package context
import (
"context"
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestVars(t *testing.T) {
expect := map[string]string{
"a": "1",
"b": "2",
}
r, err := http.NewRequest(http.MethodGet, "/", nil)
assert.Nil(t, err)
r = r.WithContext(context.WithValue(context.Background(), pathVars, expect))
assert.EqualValues(t, expect, Vars(r))
}
func TestVarsNil(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "/", nil)
assert.Nil(t, err)
assert.Nil(t, Vars(r))
}
func TestContextKey(t *testing.T) {
ck := contextKey("hello")
assert.True(t, strings.Contains(ck.String(), "hello"))
}

View File

@@ -2,7 +2,6 @@ package internal
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net/http" "net/http"
@@ -10,42 +9,27 @@ import (
) )
func StartHttp(host string, port int, handler http.Handler) error { func StartHttp(host string, port int, handler http.Handler) error {
addr := fmt.Sprintf("%s:%d", host, port) return start(host, port, handler, func(srv *http.Server) error {
server := buildHttpServer(addr, handler) return srv.ListenAndServe()
return StartServer(server) })
} }
func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler) error { func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler) error {
addr := fmt.Sprintf("%s:%d", host, port) return start(host, port, handler, func(srv *http.Server) error {
if server, err := buildHttpsServer(addr, handler, certFile, keyFile); err != nil { // certFile and keyFile are set in buildHttpsServer
return err return srv.ListenAndServeTLS(certFile, keyFile)
} else {
return StartServer(server)
}
}
func StartServer(srv *http.Server) error {
proc.AddWrapUpListener(func() {
srv.Shutdown(context.Background())
}) })
return srv.ListenAndServe()
} }
func buildHttpServer(addr string, handler http.Handler) *http.Server { func start(host string, port int, handler http.Handler, run func(srv *http.Server) error) error {
return &http.Server{Addr: addr, Handler: handler} server := &http.Server{
} Addr: fmt.Sprintf("%s:%d", host, port),
Handler: handler,
func buildHttpsServer(addr string, handler http.Handler, certFile, keyFile string) (*http.Server, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
} }
waitForCalled := proc.AddWrapUpListener(func() {
server.Shutdown(context.Background())
})
defer waitForCalled()
config := tls.Config{Certificates: []tls.Certificate{cert}} return run(server)
return &http.Server{
Addr: addr,
Handler: handler,
TLSConfig: &config,
}, nil
} }

View File

@@ -94,7 +94,6 @@ func ApiFormatByPath(apiFilePath string) error {
} }
func apiFormat(data string) (string, error) { func apiFormat(data string) (string, error) {
r := reg.ReplaceAllStringFunc(data, func(m string) string { r := reg.ReplaceAllStringFunc(data, func(m string) string {
parts := reg.FindStringSubmatch(m) parts := reg.FindStringSubmatch(m)
if len(parts) < 2 { if len(parts) < 2 {

View File

@@ -1,11 +1,9 @@
package gogen package gogen
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"os" "os"
"os/exec"
"path" "path"
"path/filepath" "path/filepath"
"strconv" "strconv"
@@ -60,7 +58,6 @@ func DoGenProject(apiFile, dir string, force bool) error {
logx.Must(genHandlers(dir, api)) logx.Must(genHandlers(dir, api))
logx.Must(genRoutes(dir, api, force)) logx.Must(genRoutes(dir, api, force))
logx.Must(genLogic(dir, api)) logx.Must(genLogic(dir, api))
createGoModFileIfNeed(dir)
if err := backupAndSweep(apiFile); err != nil { if err := backupAndSweep(apiFile); err != nil {
return err return err
@@ -129,34 +126,3 @@ func sweep() error {
return nil return nil
}) })
} }
func createGoModFileIfNeed(dir string) {
absDir, err := filepath.Abs(dir)
if err != nil {
panic(err)
}
_, hasGoMod := util.FindGoModPath(dir)
if hasGoMod {
return
}
gopath := os.Getenv("GOPATH")
parent := path.Join(gopath, "src")
pos := strings.Index(absDir, parent)
if pos >= 0 {
return
}
moduleName := absDir[len(filepath.Dir(absDir))+1:]
cmd := exec.Command("go", "mod", "init", moduleName)
cmd.Dir = dir
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err = cmd.Run(); err != nil {
fmt.Println(err.Error())
}
outStr, errStr := string(stdout.Bytes()), string(stderr.Bytes())
fmt.Printf(outStr + "\n" + errStr)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/api/parser" "github.com/tal-tech/go-zero/tools/goctl/api/parser"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
) )
const testApiTemplate = ` const testApiTemplate = `
@@ -29,6 +30,9 @@ type Response struct {
Message string ` + "`" + `json:"message"` + "`" + ` Message string ` + "`" + `json:"message"` + "`" + `
} }
@server(
group: greet
)
service A-api { service A-api {
@server( @server(
handler: GreetHandler handler: GreetHandler
@@ -37,6 +41,7 @@ service A-api {
@server( @server(
handler: NoResponseHandler handler: NoResponseHandler
) )
get /greet/get(Request) returns get /greet/get(Request) returns
} }
@@ -185,6 +190,94 @@ service A-api {
} }
` `
const apiRouteTest = `
type Request struct {
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
}
type Response struct {
Message string ` + "`" + `json:"message"` + "`" + `
}
service A-api {
@handler NormalHandler
get /greet/from/:name(Request) returns (Response)
@handler NoResponseHandler
get /greet/from/:sex(Request)
@handler NoRequestHandler
get /greet/from/request returns (Response)
@handler NoRequestNoResponseHandler
get /greet/from
}
`
const hasCommentApiTest = `
type Inline struct {
}
type Request struct {
Inline
Name string ` + "`" + `path:"name,options=you|me"` + "`" + ` // name in path
}
type Response struct {
Message string ` + "`" + `json:"msg"` + "`" + ` // message
}
service A-api {
@doc(helloworld)
@server(
handler: GreetHandler
)
get /greet/from/:name(Request) returns (Response)
}
`
const hasInlineNoExistTest = `
type Request struct {
Inline
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
}
type Response struct {
Message string ` + "`" + `json:"message"` + "`" + ` // message
}
service A-api {
@doc(helloworld)
@server(
handler: GreetHandler
)
get /greet/from/:name(Request) returns (Response)
}
`
const importApi = `
type ImportData struct {
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
}
`
const hasImportApi = `
import "importApi.api"
type Request struct {
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
}
type Response struct {
Message string ` + "`" + `json:"message"` + "`" + ` // message
}
service A-api {
@server(
handler: GreetHandler
)
get /greet/from/:name(Request) returns (Response)
}
`
func TestParser(t *testing.T) { func TestParser(t *testing.T) {
filename := "greet.api" filename := "greet.api"
err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm) err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm)
@@ -333,6 +426,79 @@ func TestApiHasNoRequestBody(t *testing.T) {
validate(t, filename) validate(t, filename)
} }
func TestApiRoutes(t *testing.T) {
filename := "greet.api"
err := ioutil.WriteFile(filename, []byte(apiRouteTest), os.ModePerm)
assert.Nil(t, err)
defer os.Remove(filename)
parser, err := parser.NewParser(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err)
validate(t, filename)
}
func TestHasCommentRoutes(t *testing.T) {
filename := "greet.api"
err := ioutil.WriteFile(filename, []byte(hasCommentApiTest), os.ModePerm)
assert.Nil(t, err)
defer os.Remove(filename)
parser, err := parser.NewParser(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err)
validate(t, filename)
}
func TestInlineTypeNotExist(t *testing.T) {
filename := "greet.api"
err := ioutil.WriteFile(filename, []byte(hasInlineNoExistTest), os.ModePerm)
assert.Nil(t, err)
defer os.Remove(filename)
parser, err := parser.NewParser(filename)
assert.Nil(t, err)
_, err = parser.Parse()
assert.Nil(t, err)
validate(t, filename)
}
func TestHasImportApi(t *testing.T) {
filename := "greet.api"
err := ioutil.WriteFile(filename, []byte(hasImportApi), os.ModePerm)
assert.Nil(t, err)
defer os.Remove(filename)
importApiName := "importApi.api"
err = ioutil.WriteFile(importApiName, []byte(importApi), os.ModePerm)
assert.Nil(t, err)
defer os.Remove(importApiName)
parser, err := parser.NewParser(filename)
assert.Nil(t, err)
api, err := parser.Parse()
assert.Nil(t, err)
var hasInline bool
for _, ty := range api.Types {
if ty.Name == "ImportData" {
hasInline = true
break
}
}
assert.True(t, hasInline)
validate(t, filename)
}
func validate(t *testing.T, api string) { func validate(t *testing.T, api string) {
dir := "_go" dir := "_go"
err := DoGenProject(api, dir, true) err := DoGenProject(api, dir, true)
@@ -346,6 +512,9 @@ func validate(t *testing.T, api string) {
} }
return nil return nil
}) })
_, err = execx.Run("go test ./...", dir)
assert.Nil(t, err)
} }
func validateCode(code string) error { func validateCode(code string) error {

View File

@@ -60,8 +60,9 @@ func genConfig(dir string, api *spec.ApiSpec) error {
"auth": strings.Join(auths, "\n"), "auth": strings.Join(auths, "\n"),
}) })
if err != nil { if err != nil {
return nil return err
} }
formatCode := formatCode(buffer.String()) formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode) _, err = fp.WriteString(formatCode)
return err return err

View File

@@ -55,6 +55,7 @@ func genEtc(dir string, api *spec.ApiSpec) error {
if err != nil { if err != nil {
return err return err
} }
formatCode := formatCode(buffer.String()) formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode) _, err = fp.WriteString(formatCode)
return err return err

View File

@@ -103,7 +103,7 @@ func doGenToFile(dir, handler string, group spec.Group, route spec.Route, handle
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
err = template.Must(template.New("handlerTemplate").Parse(text)).Execute(buffer, handleObj) err = template.Must(template.New("handlerTemplate").Parse(text)).Execute(buffer, handleObj)
if err != nil { if err != nil {
return nil return err
} }
formatCode := formatCode(buffer.String()) formatCode := formatCode(buffer.String())

View File

@@ -72,8 +72,9 @@ func genMain(dir string, api *spec.ApiSpec) error {
"serviceName": api.Service.Name, "serviceName": api.Service.Name,
}) })
if err != nil { if err != nil {
return nil return err
} }
formatCode := formatCode(buffer.String()) formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode) _, err = fp.WriteString(formatCode)
return err return err

View File

@@ -49,8 +49,9 @@ func genMiddleware(dir string, middlewares []string) error {
"name": strings.Title(name), "name": strings.Title(name),
}) })
if err != nil { if err != nil {
return nil return err
} }
formatCode := formatCode(buffer.String()) formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode) _, err = fp.WriteString(formatCode)
return err return err

View File

@@ -52,7 +52,7 @@ type (
jwtEnabled bool jwtEnabled bool
signatureEnabled bool signatureEnabled bool
authName string authName string
middleware []string middlewares []string
} }
route struct { route struct {
method string method string
@@ -92,9 +92,9 @@ func genRoutes(dir string, api *spec.ApiSpec, force bool) error {
} }
var routes string var routes string
if len(g.middleware) > 0 { if len(g.middlewares) > 0 {
gbuilder.WriteString("\n}...,") gbuilder.WriteString("\n}...,")
var params = g.middleware var params = g.middlewares
for i := range params { for i := range params {
params[i] = "serverCtx." + params[i] params[i] = "serverCtx." + params[i]
} }
@@ -143,8 +143,9 @@ func genRoutes(dir string, api *spec.ApiSpec, force bool) error {
"routesAdditions": strings.TrimSpace(builder.String()), "routesAdditions": strings.TrimSpace(builder.String()),
}) })
if err != nil { if err != nil {
return nil return err
} }
formatCode := formatCode(buffer.String()) formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode) _, err = fp.WriteString(formatCode)
return err return err
@@ -206,7 +207,7 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
} }
if value, ok := apiutil.GetAnnotationValue(g.Annotations, "server", "middleware"); ok { if value, ok := apiutil.GetAnnotationValue(g.Annotations, "server", "middleware"); ok {
for _, item := range strings.Split(value, ",") { for _, item := range strings.Split(value, ",") {
groupedRoutes.middleware = append(groupedRoutes.middleware, item) groupedRoutes.middlewares = append(groupedRoutes.middlewares, item)
} }
} }
routes = append(routes, groupedRoutes) routes = append(routes, groupedRoutes)

View File

@@ -90,8 +90,9 @@ func genServiceContext(dir string, api *spec.ApiSpec) error {
"middlewareAssignment": middlewareAssignment, "middlewareAssignment": middlewareAssignment,
}) })
if err != nil { if err != nil {
return nil return err
} }
formatCode := formatCode(buffer.String()) formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode) _, err = fp.WriteString(formatCode)
return err return err

View File

@@ -71,8 +71,9 @@ func genTypes(dir string, api *spec.ApiSpec, force bool) error {
"containsTime": api.ContainsTime(), "containsTime": api.ContainsTime(),
}) })
if err != nil { if err != nil {
return nil return err
} }
formatCode := formatCode(buffer.String()) formatCode := formatCode(buffer.String())
_, err = fp.WriteString(formatCode) _, err = fp.WriteString(formatCode)
return err return err
@@ -90,12 +91,6 @@ func convertTypeCase(types []spec.Type, t string) (string, error) {
if typ.Name == tp { if typ.Name == tp {
defTypes = append(defTypes, tp) defTypes = append(defTypes, tp)
} }
if len(typ.Annotations) > 0 {
if value, ok := apiutil.GetAnnotationValue(typ.Annotations, "serverReplacer", tp); ok {
t = strings.ReplaceAll(t, tp, value)
}
}
} }
} }

View File

@@ -10,28 +10,20 @@ import (
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
goctlutil "github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util/ctx"
"github.com/tal-tech/go-zero/tools/goctl/util/project"
) )
func getParentPackage(dir string) (string, error) { func getParentPackage(dir string) (string, error) {
p, err := project.Prepare(dir, false) abs, err := filepath.Abs(dir)
if err != nil { if err != nil {
return "", err return "", err
} }
if len(p.GoMod.Path) > 0 { projectCtx, err := ctx.Prepare(abs)
goModePath := filepath.Clean(filepath.Dir(p.GoMod.Path)) if err != nil {
absPath, err := filepath.Abs(dir) return "", err
if err != nil {
return "", err
}
parent := filepath.Clean(goctlutil.JoinPackages(p.GoMod.Module, absPath[len(goModePath):]))
parent = strings.ReplaceAll(parent, "\\", "/")
return parent, nil
} }
return filepath.ToSlash(filepath.Join(projectCtx.Path, strings.TrimPrefix(projectCtx.WorkDir, projectCtx.Dir))), nil
return p.Package, nil
} }
func writeIndent(writer io.Writer, indent int) { func writeIndent(writer io.Writer, indent int) {

View File

@@ -12,7 +12,7 @@ import (
const apiTemplate = ` const apiTemplate = `
type Request struct { type Request struct {
Name string ` + "`" + `path:"name,options=you|me"` + "`" + ` // 框架自动验证请求参数是否合法 Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
} }
type Response struct { type Response struct {

View File

@@ -39,6 +39,8 @@ func NewParser(filename string) (*Parser, error) {
if len(ip) > 0 { if len(ip) > 0 {
item := strings.TrimPrefix(item, "import") item := strings.TrimPrefix(item, "import")
item = strings.TrimSpace(item) item = strings.TrimSpace(item)
item = strings.TrimPrefix(item, `"`)
item = strings.TrimSuffix(item, `"`)
var path = item var path = item
if !util.FileExists(item) { if !util.FileExists(item) {
path = filepath.Join(filepath.Dir(apiAbsPath), item) path = filepath.Join(filepath.Dir(apiAbsPath), item)

View File

@@ -70,6 +70,12 @@ func (p *serviceEntityParser) parseLine(line string, api *spec.ApiSpec, annos []
ch, _, err := reader.ReadRune() ch, _, err := reader.ReadRune()
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
if builder.Len() > 0 {
token := strings.TrimSpace(builder.String())
if len(token) > 0 && token != returnsTag {
fields = append(fields, token)
}
}
break break
} }
return err return err

View File

@@ -125,7 +125,7 @@ func ParseApi(api string) (*ApiStruct, error) {
} }
func isImportBeginLine(line string) bool { func isImportBeginLine(line string) bool {
return strings.HasPrefix(line, "import") && strings.HasSuffix(line, ".api") return strings.HasPrefix(line, "import") && (strings.HasSuffix(line, ".api") || strings.HasSuffix(line, `.api"`))
} }
func isTypeBeginLine(line string) bool { func isTypeBeginLine(line string) bool {

View File

@@ -2,16 +2,58 @@ package docker
import ( import (
"errors" "errors"
"os"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/gen" "github.com/tal-tech/go-zero/tools/goctl/gen"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
const (
etcDir = "etc"
yamlEtx = ".yaml"
)
func DockerCommand(c *cli.Context) error { func DockerCommand(c *cli.Context) error {
goFile := c.String("go") goFile := c.String("go")
if len(goFile) == 0 { if len(goFile) == 0 {
return errors.New("-go can't be empty") return errors.New("-go can't be empty")
} }
return gen.GenerateDockerfile(goFile, "-f", "etc/config.yaml") cfg, err := findConfig(goFile, etcDir)
if err != nil {
return err
}
return gen.GenerateDockerfile(goFile, "-f", "etc/"+cfg)
}
func findConfig(file, dir string) (string, error) {
var files []string
err := filepath.Walk(dir, func(path string, f os.FileInfo, _ error) error {
if !f.IsDir() {
if filepath.Ext(f.Name()) == yamlEtx {
files = append(files, f.Name())
}
}
return nil
})
if err != nil {
return "", err
}
if len(files) == 0 {
return "", errors.New("no yaml file")
}
name := strings.TrimSuffix(filepath.Base(file), ".go")
for _, f := range files {
if strings.Index(f, name) == 0 {
return f, nil
}
}
return files[0], nil
} }

View File

@@ -6,7 +6,6 @@ import (
"text/template" "text/template"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars"
) )
func GenerateDockerfile(goFile string, args ...string) error { func GenerateDockerfile(goFile string, args ...string) error {
@@ -33,10 +32,9 @@ func GenerateDockerfile(goFile string, args ...string) error {
t := template.Must(template.New("dockerfile").Parse(dockerTemplate)) t := template.Must(template.New("dockerfile").Parse(dockerTemplate))
return t.Execute(out, map[string]string{ return t.Execute(out, map[string]string{
"projectName": vars.ProjectName, "goRelPath": projPath,
"goRelPath": projPath, "goFile": goFile,
"goFile": goFile, "exeFile": util.FileNameWithoutExt(filepath.Base(goFile)),
"exeFile": util.FileNameWithoutExt(goFile), "argument": builder.String(),
"argument": builder.String(),
}) })
} }

View File

@@ -8,8 +8,9 @@ ENV CGO_ENABLED 0
ENV GOOS linux ENV GOOS linux
ENV GOPROXY https://goproxy.cn,direct ENV GOPROXY https://goproxy.cn,direct
WORKDIR $GOPATH/src/{{.projectName}} WORKDIR /build/zero
COPY . . COPY . .
COPY {{.goRelPath}}/etc /app/etc
RUN go build -ldflags="-s -w" -o /app/{{.exeFile}} {{.goRelPath}}/{{.goFile}} RUN go build -ldflags="-s -w" -o /app/{{.exeFile}} {{.goRelPath}}/{{.goFile}}
@@ -22,6 +23,7 @@ ENV TZ Asia/Shanghai
WORKDIR /app WORKDIR /app
COPY --from=builder /app/{{.exeFile}} /app/{{.exeFile}} COPY --from=builder /app/{{.exeFile}} /app/{{.exeFile}}
COPY --from=builder /app/etc /app/etc
CMD ["./{{.exeFile}}"{{.argument}}] CMD ["./{{.exeFile}}"{{.argument}}]
` `

View File

@@ -19,13 +19,13 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/configgen" "github.com/tal-tech/go-zero/tools/goctl/configgen"
"github.com/tal-tech/go-zero/tools/goctl/docker" "github.com/tal-tech/go-zero/tools/goctl/docker"
model "github.com/tal-tech/go-zero/tools/goctl/model/sql/command" model "github.com/tal-tech/go-zero/tools/goctl/model/sql/command"
rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/command" rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/cli"
"github.com/tal-tech/go-zero/tools/goctl/tpl" "github.com/tal-tech/go-zero/tools/goctl/tpl"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
var ( var (
BuildVersion = "20201021" BuildVersion = "20201108"
commands = []cli.Command{ commands = []cli.Command{
{ {
Name: "api", Name: "api",
@@ -188,16 +188,12 @@ var (
}, },
{ {
Name: "docker", Name: "docker",
Usage: "generate Dockerfile and Makefile", Usage: "generate Dockerfile",
Flags: []cli.Flag{ Flags: []cli.Flag{
cli.StringFlag{ cli.StringFlag{
Name: "go", Name: "go",
Usage: "the file that contains main function", Usage: "the file that contains main function",
}, },
cli.StringFlag{
Name: "namespace, n",
Usage: "which namespace of kubernetes to deploy the service",
},
}, },
Action: docker.DockerCommand, Action: docker.DockerCommand,
}, },
@@ -211,7 +207,7 @@ var (
Flags: []cli.Flag{ Flags: []cli.Flag{
cli.BoolFlag{ cli.BoolFlag{
Name: "idea", Name: "idea",
Usage: "whether the command execution environment is from idea plugin. [option]", Usage: "whether the command execution environment is from idea plugin. [optional]",
}, },
}, },
Action: rpc.RpcNew, Action: rpc.RpcNew,
@@ -224,10 +220,6 @@ var (
Name: "out, o", Name: "out, o",
Usage: "the target path of proto", Usage: "the target path of proto",
}, },
cli.BoolFlag{
Name: "idea",
Usage: "whether the command execution environment is from idea plugin. [option]",
},
}, },
Action: rpc.RpcTemplate, Action: rpc.RpcTemplate,
}, },
@@ -239,17 +231,17 @@ var (
Name: "src, s", Name: "src, s",
Usage: "the file path of the proto source file", Usage: "the file path of the proto source file",
}, },
cli.StringFlag{ cli.StringSliceFlag{
Name: "dir, d", Name: "proto_path, I",
Usage: `the target path of the code,default path is "${pwd}". [option]`, Usage: `native command of protoc, specify the directory in which to search for imports. [optional]`,
}, },
cli.StringFlag{ cli.StringFlag{
Name: "service, srv", Name: "dir, d",
Usage: `the name of rpc service. [option]`, Usage: `the target path of the code`,
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "idea", Name: "idea",
Usage: "whether the command execution environment is from idea plugin. [option]", Usage: "whether the command execution environment is from idea plugin. [optional]",
}, },
}, },
Action: rpc.Rpc, Action: rpc.Rpc,
@@ -313,7 +305,7 @@ var (
}, },
cli.StringFlag{ cli.StringFlag{
Name: "style", Name: "style",
Usage: "the file naming style, lower|camel|underline,default is lower", Usage: "the file naming style, lower|camel|snake, default is lower",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "idea", Name: "idea",

View File

@@ -26,7 +26,8 @@ goctl model 为go-zero下的工具模块中的组件之一目前支持识别m
* 生成代码示例 * 生成代码示例
``` go ```go
package model package model
import ( import (
@@ -48,9 +49,9 @@ goctl model 为go-zero下的工具模块中的组件之一目前支持识别m
userRowsExpectAutoSet = strings.Join(stringx.Remove(userFieldNames, "id", "create_time", "update_time"), ",") userRowsExpectAutoSet = strings.Join(stringx.Remove(userFieldNames, "id", "create_time", "update_time"), ",")
userRowsWithPlaceHolder = strings.Join(stringx.Remove(userFieldNames, "id", "create_time", "update_time"), "=?,") + "=?" userRowsWithPlaceHolder = strings.Join(stringx.Remove(userFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
cacheUserMobilePrefix = "cache#User#mobile#"
cacheUserIdPrefix = "cache#User#id#" cacheUserIdPrefix = "cache#User#id#"
cacheUserNamePrefix = "cache#User#name#" cacheUserNamePrefix = "cache#User#name#"
cacheUserMobilePrefix = "cache#User#mobile#"
) )
type ( type (
@@ -71,23 +72,28 @@ goctl model 为go-zero下的工具模块中的组件之一目前支持识别m
} }
) )
func NewUserModel(conn sqlx.SqlConn, c cache.CacheConf, table string) *UserModel { func NewUserModel(conn sqlx.SqlConn, c cache.CacheConf) *UserModel {
return &UserModel{ return &UserModel{
CachedConn: sqlc.NewConn(conn, c), CachedConn: sqlc.NewConn(conn, c),
table: table, table: "user",
} }
} }
func (m *UserModel) Insert(data User) (sql.Result, error) { func (m *UserModel) Insert(data User) (sql.Result, error) {
query := `insert into ` + m.table + `(` + userRowsExpectAutoSet + `) value (?, ?, ?, ?, ?)` userNameKey := fmt.Sprintf("%s%v", cacheUserNamePrefix, data.Name)
return m.ExecNoCache(query, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname) userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, data.Mobile)
ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?, ?)", m.table, userRowsExpectAutoSet)
return conn.Exec(query, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname)
}, userNameKey, userMobileKey)
return ret, err
} }
func (m *UserModel) FindOne(id int64) (*User, error) { func (m *UserModel) FindOne(id int64) (*User, error) {
userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, id) userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, id)
var resp User var resp User
err := m.QueryRow(&resp, userIdKey, func(conn sqlx.SqlConn, v interface{}) error { err := m.QueryRow(&resp, userIdKey, func(conn sqlx.SqlConn, v interface{}) error {
query := `select ` + userRows + ` from ` + m.table + ` where id = ? limit 1` query := fmt.Sprintf("select %s from %s where id = ? limit 1", userRows, m.table)
return conn.QueryRow(v, query, id) return conn.QueryRow(v, query, id)
}) })
switch err { switch err {
@@ -103,18 +109,13 @@ goctl model 为go-zero下的工具模块中的组件之一目前支持识别m
func (m *UserModel) FindOneByName(name string) (*User, error) { func (m *UserModel) FindOneByName(name string) (*User, error) {
userNameKey := fmt.Sprintf("%s%v", cacheUserNamePrefix, name) userNameKey := fmt.Sprintf("%s%v", cacheUserNamePrefix, name)
var resp User var resp User
err := m.QueryRowIndex(&resp, userNameKey, func(primary interface{}) string { err := m.QueryRowIndex(&resp, userNameKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
return fmt.Sprintf("%s%v", cacheUserIdPrefix, primary) query := fmt.Sprintf("select %s from %s where name = ? limit 1", userRows, m.table)
}, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
query := `select ` + userRows + ` from ` + m.table + ` where name = ? limit 1`
if err := conn.QueryRow(&resp, query, name); err != nil { if err := conn.QueryRow(&resp, query, name); err != nil {
return nil, err return nil, err
} }
return resp.Id, nil return resp.Id, nil
}, func(conn sqlx.SqlConn, v, primary interface{}) error { }, m.queryPrimary)
query := `select ` + userRows + ` from ` + m.table + ` where id = ? limit 1`
return conn.QueryRow(v, query, primary)
})
switch err { switch err {
case nil: case nil:
return &resp, nil return &resp, nil
@@ -128,18 +129,13 @@ goctl model 为go-zero下的工具模块中的组件之一目前支持识别m
func (m *UserModel) FindOneByMobile(mobile string) (*User, error) { func (m *UserModel) FindOneByMobile(mobile string) (*User, error) {
userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, mobile) userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, mobile)
var resp User var resp User
err := m.QueryRowIndex(&resp, userMobileKey, func(primary interface{}) string { err := m.QueryRowIndex(&resp, userMobileKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
return fmt.Sprintf("%s%v", cacheUserIdPrefix, primary) query := fmt.Sprintf("select %s from %s where mobile = ? limit 1", userRows, m.table)
}, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
query := `select ` + userRows + ` from ` + m.table + ` where mobile = ? limit 1`
if err := conn.QueryRow(&resp, query, mobile); err != nil { if err := conn.QueryRow(&resp, query, mobile); err != nil {
return nil, err return nil, err
} }
return resp.Id, nil return resp.Id, nil
}, func(conn sqlx.SqlConn, v, primary interface{}) error { }, m.queryPrimary)
query := `select ` + userRows + ` from ` + m.table + ` where id = ? limit 1`
return conn.QueryRow(v, query, primary)
})
switch err { switch err {
case nil: case nil:
return &resp, nil return &resp, nil
@@ -153,7 +149,7 @@ goctl model 为go-zero下的工具模块中的组件之一目前支持识别m
func (m *UserModel) Update(data User) error { func (m *UserModel) Update(data User) error {
userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, data.Id) userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, data.Id)
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := `update ` + m.table + ` set ` + userRowsWithPlaceHolder + ` where id = ?` query := fmt.Sprintf("update %s set %s where id = ?", m.table, userRowsWithPlaceHolder)
return conn.Exec(query, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, data.Id) return conn.Exec(query, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, data.Id)
}, userIdKey) }, userIdKey)
return err return err
@@ -164,16 +160,26 @@ goctl model 为go-zero下的工具模块中的组件之一目前支持识别m
if err != nil { if err != nil {
return err return err
} }
userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, data.Mobile)
userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, id) userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, id)
userNameKey := fmt.Sprintf("%s%v", cacheUserNamePrefix, data.Name) userNameKey := fmt.Sprintf("%s%v", cacheUserNamePrefix, data.Name)
userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, data.Mobile)
_, err = m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { _, err = m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
query := `delete from ` + m.table + ` where id = ?` query := fmt.Sprintf("delete from %s where id = ?", m.table)
return conn.Exec(query, id) return conn.Exec(query, id)
}, userIdKey, userNameKey, userMobileKey) }, userMobileKey, userIdKey, userNameKey)
return err return err
} }
```
func (m *UserModel) formatPrimary(primary interface{}) string {
return fmt.Sprintf("%s%v", cacheUserIdPrefix, primary)
}
func (m *UserModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
query := fmt.Sprintf("select %s from %s where id = ? limit 1", userRows, m.table)
return conn.QueryRow(v, query, primary)
}
```
## 用法 ## 用法
@@ -212,16 +218,18 @@ OPTIONS:
``` ```
NAME: NAME:
goctl model mysql ddl - generate mysql model from ddl goctl model mysql ddl - generate mysql model from ddl
USAGE: USAGE:
goctl model mysql ddl [command options] [arguments...] goctl model mysql ddl [command options] [arguments...]
OPTIONS:
--src value, -s value the path or path globbing patterns of the ddl
--dir value, -d value the target dir
--style value the file naming style, lower|camel|underline,default is lower
--cache, -c generate code with cache [optional]
--idea for idea plugin [optional]
OPTIONS:
--src value, -s value the path or path globbing patterns of the ddl
--dir value, -d value the target dir
--cache, -c generate code with cache [optional]
--idea for idea plugin [optional]
``` ```
* datasource * datasource
@@ -233,22 +241,26 @@ OPTIONS:
help help
``` ```
NAME: NAME:
goctl model mysql datasource - generate model from datasource goctl model mysql datasource - generate model from datasource
USAGE: USAGE:
goctl model mysql datasource [command options] [arguments...] goctl model mysql datasource [command options] [arguments...]
OPTIONS:
--url value the data source of database,like "root:password@tcp(127.0.0.1:3306)/database
--table value, -t value the table or table globbing patterns in the database
--cache, -c generate code with cache [optional]
--dir value, -d value the target dir
--style value the file naming style, lower|camel|snake, default is lower
--idea for idea plugin [optional]
OPTIONS:
--url value the data source of database,like "root:password@tcp(127.0.0.1:3306)/database
--table value, -t value the table or table globbing patterns in the database
--cache, -c generate code with cache [optional]
--dir value, -d value the target dir
--idea for idea plugin [optional]
``` ```
示例用法请参考[用法](./example/generator.sh) 示例用法请参考[用法](./example/generator.sh)
> NOTE: goctl model mysql ddl/datasource 均新增了一个`--style`参数,用于标记文件命名风格。
目前仅支持redis缓存如果选择带缓存模式即生成的`FindOne(ByXxx)`&`Delete`代码会生成带缓存逻辑的代码目前仅支持单索引字段除全文索引外对于联合索引我们默认认为不需要带缓存且不属于通用型代码因此没有放在代码生成行列如example中user表中的`id`、`name`、`mobile`字段均属于单字段索引。 目前仅支持redis缓存如果选择带缓存模式即生成的`FindOne(ByXxx)`&`Delete`代码会生成带缓存逻辑的代码目前仅支持单索引字段除全文索引外对于联合索引我们默认认为不需要带缓存且不属于通用型代码因此没有放在代码生成行列如example中user表中的`id`、`name`、`mobile`字段均属于单字段索引。
* 不带缓存模式 * 不带缓存模式

View File

@@ -40,7 +40,7 @@ func MysqlDDL(ctx *cli.Context) error {
} }
switch namingStyle { switch namingStyle {
case gen.NamingLower, gen.NamingCamel, gen.NamingUnderline: case gen.NamingLower, gen.NamingCamel, gen.NamingSnake:
case "": case "":
namingStyle = gen.NamingLower namingStyle = gen.NamingLower
default: default:
@@ -87,7 +87,7 @@ func MyDataSource(ctx *cli.Context) error {
} }
switch namingStyle { switch namingStyle {
case gen.NamingLower, gen.NamingCamel, gen.NamingUnderline: case gen.NamingLower, gen.NamingCamel, gen.NamingSnake:
case "": case "":
namingStyle = gen.NamingLower namingStyle = gen.NamingLower
default: default:

View File

@@ -8,30 +8,36 @@ import (
var ( var (
commonMysqlDataTypeMap = map[string]string{ commonMysqlDataTypeMap = map[string]string{
// For consistency, all integer types are converted to int64 // For consistency, all integer types are converted to int64
"tinyint": "int64", // number
"smallint": "int64", "bool": "int64",
"mediumint": "int64", "boolean": "int64",
"int": "int64", "tinyint": "int64",
"integer": "int64", "smallint": "int64",
"bigint": "int64", "mediumint": "int64",
"float": "float64", "int": "int64",
"double": "float64", "integer": "int64",
"decimal": "float64", "bigint": "int64",
"date": "time.Time", "float": "float64",
"time": "string", "double": "float64",
"year": "int64", "decimal": "float64",
"datetime": "time.Time", // date&time
"timestamp": "time.Time", "date": "time.Time",
"datetime": "time.Time",
"timestamp": "time.Time",
"time": "string",
"year": "int64",
// string
"char": "string", "char": "string",
"varchar": "string", "varchar": "string",
"tinyblob": "string", "binary": "string",
"varbinary": "string",
"tinytext": "string", "tinytext": "string",
"blob": "string",
"text": "string", "text": "string",
"mediumblob": "string",
"mediumtext": "string", "mediumtext": "string",
"longblob": "string",
"longtext": "string", "longtext": "string",
"enum": "string",
"set": "string",
"json": "string",
} }
) )

View File

@@ -19,7 +19,7 @@ const (
createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case
NamingLower = "lower" NamingLower = "lower"
NamingCamel = "camel" NamingCamel = "camel"
NamingUnderline = "underline" NamingSnake = "snake"
) )
type ( type (
@@ -81,7 +81,7 @@ func (g *defaultGenerator) Start(withCache bool) error {
switch g.namingStyle { switch g.namingStyle {
case NamingCamel: case NamingCamel:
name = fmt.Sprintf("%sModel.go", tn.ToCamel()) name = fmt.Sprintf("%sModel.go", tn.ToCamel())
case NamingUnderline: case NamingSnake:
name = fmt.Sprintf("%s_model.go", tn.ToSnake()) name = fmt.Sprintf("%s_model.go", tn.ToSnake())
} }
filename := filepath.Join(dirAbs, name) filename := filepath.Join(dirAbs, name)

View File

@@ -1,6 +1,8 @@
package gen package gen
import ( import (
"os"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -8,27 +10,55 @@ import (
) )
var ( var (
source = "CREATE TABLE `test_user_info` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `nanosecond` bigint NOT NULL DEFAULT '0',\n `data` varchar(255) DEFAULT '',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `nanosecond_unique` (`nanosecond`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;" source = "CREATE TABLE `test_user_info` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `nanosecond` bigint NOT NULL DEFAULT '0',\n `data` varchar(255) DEFAULT '',\n `content` json DEFAULT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `nanosecond_unique` (`nanosecond`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;"
) )
func TestCacheModel(t *testing.T) { func TestCacheModel(t *testing.T) {
logx.Disable() logx.Disable()
_ = Clean() _ = Clean()
g := NewDefaultGenerator(source, "./testmodel/cache", NamingLower) dir, _ := filepath.Abs("./testmodel")
cacheDir := filepath.Join(dir, "cache")
noCacheDir := filepath.Join(dir, "nocache")
defer func() {
_ = os.RemoveAll(dir)
}()
g := NewDefaultGenerator(source, cacheDir, NamingLower)
err := g.Start(true) err := g.Start(true)
assert.Nil(t, err) assert.Nil(t, err)
g = NewDefaultGenerator(source, "./testmodel/nocache", NamingLower) assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(cacheDir, "testuserinfomodel.go"))
return err == nil
}())
g = NewDefaultGenerator(source, noCacheDir, NamingLower)
err = g.Start(false) err = g.Start(false)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(noCacheDir, "testuserinfomodel.go"))
return err == nil
}())
} }
func TestNamingModel(t *testing.T) { func TestNamingModel(t *testing.T) {
logx.Disable() logx.Disable()
_ = Clean() _ = Clean()
g := NewDefaultGenerator(source, "./testmodel/camel", NamingCamel) dir, _ := filepath.Abs("./testmodel")
camelDir := filepath.Join(dir, "camel")
snakeDir := filepath.Join(dir, "snake")
defer func() {
_ = os.RemoveAll(dir)
}()
g := NewDefaultGenerator(source, camelDir, NamingCamel)
err := g.Start(true) err := g.Start(true)
assert.Nil(t, err) assert.Nil(t, err)
g = NewDefaultGenerator(source, "./testmodel/snake", NamingUnderline) assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go"))
return err == nil
}())
g = NewDefaultGenerator(source, snakeDir, NamingSnake)
err = g.Start(true) err = g.Start(true)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, func() bool {
_, err := os.Stat(filepath.Join(snakeDir, "test_user_info_model.go"))
return err == nil
}())
} }

View File

@@ -1,20 +0,0 @@
# Change log
## 2020-10-19
* 增加template
## 2020-09-10
* rpc greet服务一键生成
* 修复相对路径生成rpc服务package引入错误bug
* 移除`--shared`参数
## 2020-08-29
* 新增支持windows生成
## 2020-08-27
* 新增支持rpc模板生成
* 新增支持rpc服务生成

View File

@@ -7,9 +7,8 @@ Goctl Rpc是`goctl`脚手架下的一个rpc服务代码生成模块支持prot
* 简单易用 * 简单易用
* 快速提升开发效率 * 快速提升开发效率
* 出错率低 * 出错率低
* 支持基于main proto作为相对路径的import * 贴近protoc
* 支持map、enum类型
* 支持any类型
## 快速开始 ## 快速开始
@@ -19,44 +18,64 @@ Goctl Rpc是`goctl`脚手架下的一个rpc服务代码生成模块支持prot
如生成greet rpc服务 如生成greet rpc服务
```shell script ```Bash
goctl rpc new greet goctl rpc new greet
``` ```
执行后代码结构如下: 执行后代码结构如下:
```golang ```golang
└── greet .
├── etc ├── etc // yaml配置文件
│   └── greet.yaml └── greet.yaml
├── go.mod ├── go.mod
├── go.sum ├── greet // pb.go文件夹①
── greet ── greet.pb.go
│   ├── greet.go ├── greet.go // main函数
│   ├── greet_mock.go ├── greet.proto // proto 文件
│   └── types.go ├── greetclient // call logic ②
── greet.go ── greet.go
├── greet.proto └── internal
├── internal ├── config // yaml配置对应的实体
  ── config ── config.go
│   │   └── config.go ├── logic // 业务代码
   ├── logic └── pinglogic.go
│   │   └── pinglogic.go ├── server // rpc server
   ├── server └── greetserver.go
│   │   └── greetserver.go └── svc // 依赖资源
│   └── svc └── servicecontext.go
│   └── servicecontext.go
└── pb
└── greet.pb.go
``` ```
rpc一键生成常见问题解决见 <a href="#常见问题解决">常见问题解决</a> > ① pb文件夹名老版本文件夹固定为pb称取自于proto文件中option go_package的值最后一层级按照一定格式进行转换若无此声明则取自于package的值大致代码如下
```go
if option.Name == "go_package" {
ret.GoPackage = option.Constant.Source
}
...
if len(ret.GoPackage) == 0 {
ret.GoPackage = ret.Package.Name
}
ret.PbPackage = GoSanitized(filepath.Base(ret.GoPackage))
...
```
> GoSanitized方法请参考google.golang.org/protobuf@v1.25.0/internal/strs/strings.go:71
> ② call 层文件夹名称取自于proto中service的名称如该sercice的名称和pb文件夹名称相等则会在srervice后面补充client进行区分使pb和call分隔。
```go
if strings.ToLower(proto.Service.Name) == strings.ToLower(proto.GoPackage) {
callDir = filepath.Join(ctx.WorkDir, strings.ToLower(stringx.From(proto.Service.Name+"_client").ToCamel()))
}
```
rpc一键生成常见问题解决见 <a href="#常见问题解决">常见问题解决</a>
### 方式二通过指定proto生成rpc服务 ### 方式二通过指定proto生成rpc服务
* 生成proto模板 * 生成proto模板
```shell script ```Bash
goctl rpc template -o=user.proto goctl rpc template -o=user.proto
``` ```
@@ -87,35 +106,10 @@ rpc一键生成常见问题解决见 <a href="#常见问题解决">常见问题
* 生成rpc服务代码 * 生成rpc服务代码
```shell script ```Bash
goctl rpc proto -src=user.proto goctl rpc proto -src=user.proto
``` ```
代码tree
```Plain Text
user
├── etc
│   └── user.json
├── internal
│   ├── config
│   │   └── config.go
│   ├── handler
│   │   ├── loginhandler.go
│   ├── logic
│   │   └── loginlogic.go
│   └── svc
│   └── servicecontext.go
├── pb
│   └── user.pb.go
├── shared
│   ├── mockusermodel.go
│   ├── types.go
│   └── usermodel.go
├── user.go
└── user.proto
```
## 准备工作 ## 准备工作
* 安装了go环境 * 安装了go环境
@@ -126,11 +120,11 @@ rpc一键生成常见问题解决见 <a href="#常见问题解决">常见问题
### rpc服务生成用法 ### rpc服务生成用法
```shell script ```Bash
goctl rpc proto -h goctl rpc proto -h
``` ```
```shell script ```Bash
NAME: NAME:
goctl rpc proto - generate rpc from proto goctl rpc proto - generate rpc from proto
@@ -139,35 +133,22 @@ USAGE:
OPTIONS: OPTIONS:
--src value, -s value the file path of the proto source file --src value, -s value the file path of the proto source file
--dir value, -d value the target path of the code,default path is "${pwd}". [option] --proto_path value, -I value native command of protoc, specify the directory in which to search for imports. [optional]
--service value, --srv value the name of rpc service. [option] --dir value, -d value the target path of the code
--idea whether the command execution environment is from idea plugin. [option] --idea whether the command execution environment is from idea plugin. [optional]
``` ```
### 参数说明 ### 参数说明
* --src 必填proto数据源目前暂时支持单个proto文件生成,这里不支持(不建议)外部依赖 * --src 必填proto数据源目前暂时支持单个proto文件生成
* --dir 非必填默认为proto文件所在目录生成代码的目标目录 * --proto_path 可选protoc原生子命令用于指定proto import从何处查找可指定多个路径,如`goctl rpc -I={path1} -I={path2} ...`,在没有import时可不填。当前proto路径不用指定已经内置`-I`的详细用法请参考`protoc -h`
* --service 服务名称非必填默认为proto文件所在目录名称但是如果proto所在目录为一下结构 * --dir 可选默认为proto文件所在目录生成代码的目标目录
* --idea 可选是否为idea插件中执行终端执行可以忽略
```shell script
user
├── cmd
│   └── rpc
│   └── user.proto
```
则服务名称亦为user而非proto所在文件夹名称了这里推荐使用这种结构可以方便在同一个服务名下建立不同类型的服务(api、rpc、mq等),便于代码管理与维护。
> 注意这里的shared文件夹名称将会是代码中的package名称。
* --idea 非必填是否为idea插件中执行保留字段终端执行可以忽略
### 开发人员需要做什么 ### 开发人员需要做什么
关注业务代码编写将重复性、与业务无关的工作交给goctl生成好rpc服务代码后人员仅需要修改 关注业务代码编写将重复性、与业务无关的工作交给goctl生成好rpc服务代码后人员仅需要修改
* 服务中的配置文件编写(etc/xx.json、internal/config/config.go) * 服务中的配置文件编写(etc/xx.json、internal/config/config.go)
* 服务中业务逻辑编写(internal/logic/xxlogic.go) * 服务中业务逻辑编写(internal/logic/xxlogic.go)
@@ -193,69 +174,54 @@ OPTIONS:
的标识,请注意不要将也写业务性代码写在里面。 的标识,请注意不要将也写业务性代码写在里面。
## any和import支持 ## proto import
* 支持any类型声明 * 对于rpc中的requestType和returnType必须在main proto文件定义对于proto中的message可以像protoc一样import其他proto文件。
* 支持import其他proto文件
any类型固定import为`google/protobuf/any.proto`,且从${GOPATH}/src中查找proto的import支持main proto的相对路径的import且与proto文件对应的pb.go文件必须在proto目录中能被找到。不支持工程外的其他proto文件import。 proto示例:
> ⚠️注意: 不支持proto嵌套import被import的proto文件不支持import。 ### 错误import
```proto
### import书写格式
import书写格式
```golang
// @{package_of_pb}
import {proto_omport}
```
@{package_of_pb}pb文件的真实import目录。
{proto_omport}proto import
demo中的
```golang
// @greet/base
import "base/base.proto";
```
工程目录结构如下
```
greet
│   ├── base
│   │   ├── base.pb.go
│   │   └── base.proto
│   ├── demo.proto
│   ├── go.mod
│   └── go.sum
```
demo
```golang
syntax = "proto3"; syntax = "proto3";
import "google/protobuf/any.proto";
// @greet/base
import "base/base.proto";
package stream;
package greet;
enum Gender{ import "base/common.proto"
UNKNOWN = 0;
MAN = 1; message Request {
WOMAN = 2; string ping = 1;
} }
message StreamResp{ message Response {
string name = 2; string pong = 1;
Gender gender = 3;
google.protobuf.Any details = 5;
base.StreamReq req = 6;
} }
service StreamGreeter {
rpc greet(base.StreamReq) returns (StreamResp); service Greet {
rpc Ping(base.In) returns(base.Out);// request和return 不支持import
} }
``` ```
### 正确import
```proto
syntax = "proto3";
package greet;
import "base/common.proto"
message Request {
base.In in = 1;// 支持import
}
message Response {
base.Out out = 2;// 支持import
}
service Greet {
rpc Ping(Request) returns(Response);
}
```
## 常见问题解决(go mod工程) ## 常见问题解决(go mod工程)

View File

@@ -1,108 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: base.proto
package base
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type IdRequest struct {
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *IdRequest) Reset() { *m = IdRequest{} }
func (m *IdRequest) String() string { return proto.CompactTextString(m) }
func (*IdRequest) ProtoMessage() {}
func (*IdRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_db1b6b0986796150, []int{0}
}
func (m *IdRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_IdRequest.Unmarshal(m, b)
}
func (m *IdRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_IdRequest.Marshal(b, m, deterministic)
}
func (m *IdRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_IdRequest.Merge(m, src)
}
func (m *IdRequest) XXX_Size() int {
return xxx_messageInfo_IdRequest.Size(m)
}
func (m *IdRequest) XXX_DiscardUnknown() {
xxx_messageInfo_IdRequest.DiscardUnknown(m)
}
var xxx_messageInfo_IdRequest proto.InternalMessageInfo
func (m *IdRequest) GetId() string {
if m != nil {
return m.Id
}
return ""
}
type EmptyResponse struct {
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *EmptyResponse) Reset() { *m = EmptyResponse{} }
func (m *EmptyResponse) String() string { return proto.CompactTextString(m) }
func (*EmptyResponse) ProtoMessage() {}
func (*EmptyResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_db1b6b0986796150, []int{1}
}
func (m *EmptyResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_EmptyResponse.Unmarshal(m, b)
}
func (m *EmptyResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_EmptyResponse.Marshal(b, m, deterministic)
}
func (m *EmptyResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_EmptyResponse.Merge(m, src)
}
func (m *EmptyResponse) XXX_Size() int {
return xxx_messageInfo_EmptyResponse.Size(m)
}
func (m *EmptyResponse) XXX_DiscardUnknown() {
xxx_messageInfo_EmptyResponse.DiscardUnknown(m)
}
var xxx_messageInfo_EmptyResponse proto.InternalMessageInfo
func init() {
proto.RegisterType((*IdRequest)(nil), "base.IdRequest")
proto.RegisterType((*EmptyResponse)(nil), "base.EmptyResponse")
}
func init() { proto.RegisterFile("base.proto", fileDescriptor_db1b6b0986796150) }
var fileDescriptor_db1b6b0986796150 = []byte{
// 91 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4a, 0x4a, 0x2c, 0x4e,
0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x01, 0xb1, 0x95, 0xa4, 0xb9, 0x38, 0x3d, 0x53,
0x82, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0xf8, 0xb8, 0x98, 0x32, 0x53, 0x24, 0x18, 0x15,
0x18, 0x35, 0x38, 0x83, 0x98, 0x32, 0x53, 0x94, 0xf8, 0xb9, 0x78, 0x5d, 0x73, 0x0b, 0x4a, 0x2a,
0x83, 0x52, 0x8b, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x93, 0xd8, 0xc0, 0x5a, 0x8d, 0x01, 0x01, 0x00,
0x00, 0xff, 0xff, 0xe1, 0x39, 0x3c, 0x22, 0x48, 0x00, 0x00, 0x00,
}

View File

@@ -1,11 +0,0 @@
syntax = "proto3";
package base;
message IdRequest {
string id = 1;
}
message EmptyResponse {
}

View File

@@ -0,0 +1,68 @@
package cli
import (
"errors"
"fmt"
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
"github.com/urfave/cli"
)
// Rpc is to generate rpc service code from a proto file by specifying a proto file using flag src,
// you can specify a target folder for code generation, when the proto file has import, you can specify
// the import search directory through the proto_path command, for specific usage, please refer to protoc -h
func Rpc(c *cli.Context) error {
src := c.String("src")
out := c.String("dir")
protoImportPath := c.StringSlice("proto_path")
if len(src) == 0 {
return errors.New("missing -src")
}
if len(out) == 0 {
return errors.New("missing -dir")
}
g := generator.NewDefaultRpcGenerator()
return g.Generate(src, out, protoImportPath)
}
// RpcNew is to generate rpc greet service, this greet service can speed
// up your understanding of the zrpc service structure
func RpcNew(c *cli.Context) error {
name := c.Args().First()
ext := filepath.Ext(name)
if len(ext) > 0 {
return fmt.Errorf("unexpected ext: %s", ext)
}
protoName := name + ".proto"
filename := filepath.Join(".", name, protoName)
src, err := filepath.Abs(filename)
if err != nil {
return err
}
err = generator.ProtoTmpl(src)
if err != nil {
return err
}
workDir := filepath.Dir(src)
_, err = execx.Run("go mod init "+name, workDir)
if err != nil {
return err
}
g := generator.NewDefaultRpcGenerator()
return g.Generate(src, filepath.Dir(src), nil)
}
func RpcTemplate(c *cli.Context) error {
protoFile := c.String("o")
if len(protoFile) == 0 {
return errors.New("missing -o")
}
return generator.ProtoTmpl(protoFile)
}

View File

@@ -1,60 +0,0 @@
package command
import (
"fmt"
"os"
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/ctx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/gen"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli"
)
func Rpc(c *cli.Context) error {
rpcCtx := ctx.MustCreateRpcContextFromCli(c)
generator := gen.NewDefaultRpcGenerator(rpcCtx)
rpcCtx.Must(generator.Generate())
return nil
}
func RpcTemplate(c *cli.Context) error {
out := c.String("out")
idea := c.Bool("idea")
generator := gen.NewRpcTemplate(out, idea)
generator.MustGenerate(true)
return nil
}
func RpcNew(c *cli.Context) error {
idea := c.Bool("idea")
arg := c.Args().First()
if len(arg) == 0 {
arg = "greet"
}
abs, err := filepath.Abs(arg)
if err != nil {
return err
}
_, err = os.Stat(abs)
if err != nil {
if !os.IsNotExist(err) {
return err
}
err = util.MkdirIfNotExist(abs)
if err != nil {
return err
}
}
dir := filepath.Base(filepath.Clean(abs))
protoSrc := filepath.Join(abs, fmt.Sprintf("%v.proto", dir))
templateGenerator := gen.NewRpcTemplate(protoSrc, idea)
templateGenerator.MustGenerate(false)
rpcCtx := ctx.MustCreateRpcContext(protoSrc, "", "", idea)
generator := gen.NewDefaultRpcGenerator(rpcCtx)
rpcCtx.Must(generator.Generate())
return nil
}

View File

@@ -1,99 +0,0 @@
package ctx
import (
"fmt"
"path/filepath"
"runtime"
"strings"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/project"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
"github.com/tal-tech/go-zero/tools/goctl/vars"
"github.com/urfave/cli"
)
const (
flagSrc = "src"
flagDir = "dir"
flagService = "service"
flagIdea = "idea"
)
type RpcContext struct {
ProjectPath string
ProjectName stringx.String
ServiceName stringx.String
CurrentPath string
Module string
ProtoFileSrc string
ProtoSource string
TargetDir string
IsInGoEnv bool
console.Console
}
func MustCreateRpcContext(protoSrc, targetDir, serviceName string, idea bool) *RpcContext {
log := console.NewConsole(idea)
if stringx.From(protoSrc).IsEmptyOrSpace() {
log.Fatalln("expected proto source, but nothing found")
}
srcFp, err := filepath.Abs(protoSrc)
log.Must(err)
if !util.FileExists(srcFp) {
log.Fatalln("%s is not exists", srcFp)
}
current := filepath.Dir(srcFp)
if stringx.From(targetDir).IsEmptyOrSpace() {
targetDir = current
}
targetDirFp, err := filepath.Abs(targetDir)
log.Must(err)
if stringx.From(serviceName).IsEmptyOrSpace() {
serviceName = getServiceFromRpcStructure(targetDirFp)
}
serviceNameString := stringx.From(serviceName)
if serviceNameString.IsEmptyOrSpace() {
log.Fatalln("service name not found")
}
info, err := project.Prepare(targetDir, true)
log.Must(err)
return &RpcContext{
ProjectPath: info.Path,
ProjectName: stringx.From(info.Name),
ServiceName: serviceNameString,
CurrentPath: current,
Module: info.GoMod.Module,
ProtoFileSrc: srcFp,
ProtoSource: filepath.Base(srcFp),
TargetDir: targetDirFp,
IsInGoEnv: info.IsInGoEnv,
Console: log,
}
}
func MustCreateRpcContextFromCli(ctx *cli.Context) *RpcContext {
os := runtime.GOOS
switch os {
case vars.OsMac, vars.OsLinux, vars.OsWindows:
default:
logx.Must(fmt.Errorf("unexpected os: %s", os))
}
protoSrc := ctx.String(flagSrc)
targetDir := ctx.String(flagDir)
serviceName := ctx.String(flagService)
idea := ctx.Bool(flagIdea)
return MustCreateRpcContext(protoSrc, targetDir, serviceName, idea)
}
func getServiceFromRpcStructure(targetDir string) string {
targetDir = filepath.Clean(targetDir)
suffix := filepath.Join("cmd", "rpc")
return filepath.Base(strings.TrimSuffix(targetDir, suffix))
}

View File

@@ -6,7 +6,9 @@ import (
"fmt" "fmt"
"os/exec" "os/exec"
"runtime" "runtime"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars" "github.com/tal-tech/go-zero/tools/goctl/vars"
) )
@@ -24,17 +26,17 @@ func Run(arg string, dir string) (string, error) {
if len(dir) > 0 { if len(dir) > 0 {
cmd.Dir = dir cmd.Dir = dir
} }
dtsout := new(bytes.Buffer) stdout := new(bytes.Buffer)
stderr := new(bytes.Buffer) stderr := new(bytes.Buffer)
cmd.Stdout = dtsout cmd.Stdout = stdout
cmd.Stderr = stderr cmd.Stderr = stderr
err := cmd.Run() err := cmd.Run()
if err != nil { if err != nil {
if stderr.Len() > 0 { if stderr.Len() > 0 {
return "", errors.New(stderr.String()) return "", errors.New(strings.TrimSuffix(stderr.String(), util.NL))
} }
return "", err return "", err
} }
return dtsout.String(), nil return strings.TrimSuffix(stdout.String(), util.NL), nil
} }

View File

@@ -1,93 +0,0 @@
package gen
import (
"github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/tools/goctl/rpc/ctx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
)
const (
dirTarget = "dirTarget"
dirConfig = "config"
dirEtc = "etc"
dirSvc = "svc"
dirServer = "server"
dirLogic = "logic"
dirPb = "pb"
dirInternal = "internal"
fileConfig = "config.go"
fileServiceContext = "servicecontext.go"
)
type defaultRpcGenerator struct {
dirM map[string]string
Ctx *ctx.RpcContext
ast *parser.PbAst
}
func NewDefaultRpcGenerator(ctx *ctx.RpcContext) *defaultRpcGenerator {
return &defaultRpcGenerator{
Ctx: ctx,
}
}
func (g *defaultRpcGenerator) Generate() (err error) {
g.Ctx.Info(aurora.Blue("-> goctl rpc reference documents: ").String() + "「https://github.com/tal-tech/zero-doc/blob/main/doc/goctl-rpc.md」")
g.Ctx.Warning("-> generating rpc code ...")
defer func() {
if err == nil {
g.Ctx.MarkDone()
}
}()
err = g.createDir()
if err != nil {
return
}
err = g.initGoMod()
if err != nil {
return
}
err = g.genEtc()
if err != nil {
return
}
err = g.genPb()
if err != nil {
return
}
err = g.genConfig()
if err != nil {
return
}
err = g.genSvc()
if err != nil {
return
}
err = g.genLogic()
if err != nil {
return
}
err = g.genHandler()
if err != nil {
return
}
err = g.genMain()
if err != nil {
return
}
err = g.genCall()
if err != nil {
return
}
return
}

View File

@@ -1,224 +0,0 @@
package gen
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const (
typesFilename = "types.go"
callTemplateText = `{{.head}}
//go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE
package {{.filePackage}}
import (
"context"
{{.package}}
"github.com/tal-tech/go-zero/core/jsonx"
"github.com/tal-tech/go-zero/zrpc"
)
type (
{{.serviceName}} interface {
{{.interface}}
}
default{{.serviceName}} struct {
cli zrpc.Client
}
)
func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
return &default{{.serviceName}}{
cli: cli,
}
}
{{.functions}}
`
callTemplateTypes = `{{.head}}
package {{.filePackage}}
import "errors"
var errJsonConvert = errors.New("json convert error")
{{.const}}
{{.types}}
`
callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
callFunctionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequestName}}) (*{{.pbResponse}}, error) {
var request {{.pbRequest}}
bts, err := jsonx.Marshal(in)
if err != nil {
return nil, errJsonConvert
}
err = jsonx.Unmarshal(bts, &request)
if err != nil {
return nil, errJsonConvert
}
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
resp, err := client.{{.method}}(ctx, &request)
if err != nil{
return nil, err
}
var ret {{.pbResponse}}
bts, err = jsonx.Marshal(resp)
if err != nil{
return nil, errJsonConvert
}
err = jsonx.Unmarshal(bts, &ret)
if err != nil{
return nil, errJsonConvert
}
return &ret, nil
}
`
)
func (g *defaultRpcGenerator) genCall() error {
file := g.ast
if len(file.Service) == 0 {
return nil
}
if len(file.Service) > 1 {
return fmt.Errorf("we recommend only one service in a proto, currently %d", len(file.Service))
}
typeCode, err := file.GenTypesCode()
if err != nil {
return err
}
constLit, err := file.GenEnumCode()
if err != nil {
return err
}
service := file.Service[0]
callPath := filepath.Join(g.dirM[dirTarget], service.Name.Lower())
if err = util.MkdirIfNotExist(callPath); err != nil {
return err
}
filename := filepath.Join(callPath, typesFilename)
head := util.GetHead(g.Ctx.ProtoSource)
text, err := util.LoadTemplate(category, callTypesTemplateFile, callTemplateTypes)
if err != nil {
return err
}
err = util.With("types").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head,
"const": constLit,
"filePackage": service.Name.Lower(),
"serviceName": g.Ctx.ServiceName.Title(),
"lowerStartServiceName": g.Ctx.ServiceName.UnTitle(),
"types": typeCode,
}, filename, true)
if err != nil {
return err
}
filename = filepath.Join(callPath, fmt.Sprintf("%s.go", service.Name.Lower()))
functions, importList, err := g.genFunction(service)
if err != nil {
return err
}
iFunctions, err := g.getInterfaceFuncs(service)
if err != nil {
return err
}
text, err = util.LoadTemplate(category, callTemplateFile, callTemplateText)
if err != nil {
return err
}
err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"name": service.Name.Lower(),
"head": head,
"filePackage": service.Name.Lower(),
"package": strings.Join(importList, util.NL),
"serviceName": service.Name.Title(),
"functions": strings.Join(functions, util.NL),
"interface": strings.Join(iFunctions, util.NL),
}, filename, true)
return err
}
func (g *defaultRpcGenerator) genFunction(service *parser.RpcService) ([]string, []string, error) {
file := g.ast
pkgName := file.Package
functions := make([]string, 0)
imports := collection.NewSet()
imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb)))
for _, method := range service.Funcs {
imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
text, err := util.LoadTemplate(category, callFunctionTemplateFile, callFunctionTemplate)
if err != nil {
return nil, nil, err
}
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
"rpcServiceName": service.Name.Title(),
"method": method.Name.Title(),
"package": pkgName,
"pbRequestName": method.ParameterIn.Name,
"pbRequest": method.ParameterIn.Expression,
"pbResponse": method.ParameterOut.Name,
"hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
})
if err != nil {
return nil, nil, err
}
functions = append(functions, buffer.String())
}
return functions, imports.KeysStr(), nil
}
func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) {
functions := make([]string, 0)
for _, method := range service.Funcs {
text, err := util.LoadTemplate(category, callInterfaceFunctionTemplateFile, callInterfaceFunctionTemplate)
if err != nil {
return nil, err
}
buffer, err := util.With("interfaceFn").Parse(text).Execute(
map[string]interface{}{
"hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
"method": method.Name.Title(),
"pbRequest": method.ParameterIn.Name,
"pbResponse": method.ParameterOut.Name,
})
if err != nil {
return nil, err
}
functions = append(functions, buffer.String())
}
return functions, nil
}

View File

@@ -1,54 +0,0 @@
package gen
import (
"path/filepath"
"runtime"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars"
)
// target
// ├── etc
// ├── internal
// │   ├── config
// │   ├── handler
// │   ├── logic
// │   ├── pb
// │   └── svc
func (g *defaultRpcGenerator) createDir() error {
ctx := g.Ctx
m := make(map[string]string)
m[dirTarget] = ctx.TargetDir
m[dirEtc] = filepath.Join(ctx.TargetDir, dirEtc)
m[dirInternal] = filepath.Join(ctx.TargetDir, dirInternal)
m[dirConfig] = filepath.Join(ctx.TargetDir, dirInternal, dirConfig)
m[dirServer] = filepath.Join(ctx.TargetDir, dirInternal, dirServer)
m[dirLogic] = filepath.Join(ctx.TargetDir, dirInternal, dirLogic)
m[dirPb] = filepath.Join(ctx.TargetDir, dirPb)
m[dirSvc] = filepath.Join(ctx.TargetDir, dirInternal, dirSvc)
for _, d := range m {
err := util.MkdirIfNotExist(d)
if err != nil {
return err
}
}
g.dirM = m
return nil
}
func (g *defaultRpcGenerator) mustGetPackage(dir string) string {
target := g.dirM[dir]
projectPath := g.Ctx.ProjectPath
relativePath := strings.TrimPrefix(target, projectPath)
os := runtime.GOOS
switch os {
case vars.OsWindows:
relativePath = filepath.ToSlash(relativePath)
case vars.OsMac, vars.OsLinux:
default:
g.Ctx.Fatalln("unexpected os: %s", os)
}
return g.Ctx.Module + relativePath
}

View File

@@ -1,109 +0,0 @@
package gen
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const (
logicTemplate = `package logic
import (
"context"
{{.imports}}
"github.com/tal-tech/go-zero/core/logx"
)
type {{.logicName}} struct {
ctx context.Context
svcCtx *svc.ServiceContext
logx.Logger
}
func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} {
return &{{.logicName}}{
ctx: ctx,
svcCtx: svcCtx,
Logger: logx.WithContext(ctx),
}
}
{{.functions}}
`
logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
// todo: add your logic here and delete this line
return &{{.responseType}}{}, nil
}
`
)
func (g *defaultRpcGenerator) genLogic() error {
logicPath := g.dirM[dirLogic]
protoPkg := g.ast.Package
service := g.ast.Service
for _, item := range service {
for _, method := range item.Funcs {
logicName := fmt.Sprintf("%slogic.go", method.Name.Lower())
filename := filepath.Join(logicPath, logicName)
functions, importList, err := g.genLogicFunction(protoPkg, method)
if err != nil {
return err
}
imports := collection.NewSet()
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
imports.AddStr(svcImport)
imports.AddStr(importList...)
text, err := util.LoadTemplate(category, logicTemplateFileFile, logicTemplate)
if err != nil {
return err
}
err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"functions": functions,
"imports": strings.Join(imports.KeysStr(), util.NL),
}, filename, false)
if err != nil {
return err
}
}
}
return nil
}
func (g *defaultRpcGenerator) genLogicFunction(packageName string, method *parser.Func) (string, []string, error) {
var functions = make([]string, 0)
var imports = collection.NewSet()
if method.ParameterIn.Package == packageName || method.ParameterOut.Package == packageName {
imports.AddStr(fmt.Sprintf(`%v "%v"`, packageName, g.mustGetPackage(dirPb)))
}
imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
if err != nil {
return "", nil, err
}
buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(),
"request": method.ParameterIn.StarExpression,
"response": method.ParameterOut.StarExpression,
"responseType": method.ParameterOut.Expression,
"hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
})
if err != nil {
return "", nil, err
}
functions = append(functions, buffer.String())
return strings.Join(functions, util.NL), imports.KeysStr(), nil
}

View File

@@ -1,85 +0,0 @@
package gen
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const mainTemplate = `{{.head}}
package main
import (
"flag"
"fmt"
{{.imports}}
"github.com/tal-tech/go-zero/core/conf"
"github.com/tal-tech/go-zero/zrpc"
"google.golang.org/grpc"
)
var configFile = flag.String("f", "etc/{{.serviceName}}.yaml", "the config file")
func main() {
flag.Parse()
var c config.Config
conf.MustLoad(*configFile, &c)
ctx := svc.NewServiceContext(c)
{{.srv}}
s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) {
{{.registers}}
})
defer s.Stop()
fmt.Printf("Starting rpc server at %s...\n", c.ListenOn)
s.Start()
}
`
func (g *defaultRpcGenerator) genMain() error {
mainPath := g.dirM[dirTarget]
file := g.ast
pkg := file.Package
fileName := filepath.Join(mainPath, fmt.Sprintf("%v.go", g.Ctx.ServiceName.Lower()))
imports := make([]string, 0)
pbImport := fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb))
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
remoteImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirServer))
configImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig))
imports = append(imports, configImport, pbImport, remoteImport, svcImport)
srv, registers := g.genServer(pkg, file.Service)
head := util.GetHead(g.Ctx.ProtoSource)
text, err := util.LoadTemplate(category, mainTemplateFile, mainTemplate)
if err != nil {
return err
}
return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head,
"package": pkg,
"serviceName": g.Ctx.ServiceName.Lower(),
"srv": srv,
"registers": registers,
"imports": strings.Join(imports, util.NL),
}, fileName, false)
}
func (g *defaultRpcGenerator) genServer(pkg string, list []*parser.RpcService) (string, string) {
list1 := make([]string, 0)
list2 := make([]string, 0)
for _, item := range list {
name := item.Name.UnTitle()
list1 = append(list1, fmt.Sprintf("%sSrv := server.New%sServer(ctx)", name, item.Name.Title()))
list2 = append(list2, fmt.Sprintf("%s.Register%sServer(grpcServer, %sSrv)", pkg, item.Name.Title(), name))
}
return strings.Join(list1, util.NL), strings.Join(list2, util.NL)
}

View File

@@ -1,82 +0,0 @@
package gen
import (
"bytes"
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
)
const (
protocCmd = "protoc"
grpcPluginCmd = "--go_out=plugins=grpc"
)
func (g *defaultRpcGenerator) genPb() error {
pbPath := g.dirM[dirPb]
// deprecated: containsAny will be removed in the feature
imports, containsAny, err := parser.ParseImport(g.Ctx.ProtoFileSrc)
if err != nil {
return err
}
err = g.protocGenGo(pbPath, imports)
if err != nil {
return err
}
ast, err := parser.Transfer(g.Ctx.ProtoFileSrc, pbPath, imports, g.Ctx.Console)
if err != nil {
return err
}
ast.ContainsAny = containsAny
if len(ast.Service) == 0 {
return fmt.Errorf("service not found")
}
g.ast = ast
return nil
}
func (g *defaultRpcGenerator) protocGenGo(target string, imports []*parser.Import) error {
dir := filepath.Dir(g.Ctx.ProtoFileSrc)
// cmd join,see the document of proto generating class @https://developers.google.com/protocol-buffers/docs/proto3#generating
// template: protoc -I=${import_path} -I=${other_import_path} -I=${...} --go_out=plugins=grpc,M${pb_package_kv}, M${...} :${target_dir}
// eg: protoc -I=${GOPATH}/src -I=. example.proto --go_out=plugins=grpc,Mbase/base.proto=github.com/go-zero/base.proto:.
// note: the external import out of the project which are found in ${GOPATH}/src so far.
buffer := new(bytes.Buffer)
buffer.WriteString(protocCmd + " ")
targetImportFiltered := collection.NewSet()
for _, item := range imports {
buffer.WriteString(fmt.Sprintf("-I=%s ", item.OriginalDir))
if len(item.BridgeImport) == 0 {
continue
}
targetImportFiltered.AddStr(item.BridgeImport)
}
buffer.WriteString("-I=${GOPATH}/src ")
buffer.WriteString(fmt.Sprintf("-I=%s %s ", dir, g.Ctx.ProtoFileSrc))
buffer.WriteString(grpcPluginCmd)
if targetImportFiltered.Count() > 0 {
buffer.WriteString(fmt.Sprintf(",%v", strings.Join(targetImportFiltered.KeysStr(), ",")))
}
buffer.WriteString(":" + target)
g.Ctx.Debug("-> " + buffer.String())
stdout, err := execx.Run(buffer.String(), "")
if err != nil {
return err
}
if len(stdout) > 0 {
g.Ctx.Info(stdout)
}
return nil
}

View File

@@ -1,116 +0,0 @@
package gen
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const (
serverTemplate = `{{.head}}
package server
import (
"context"
{{.imports}}
)
type {{.types}}
func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
return &{{.server}}Server{
svcCtx: svcCtx,
}
}
{{.funcs}}
`
functionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
l := logic.New{{.logicName}}(ctx,s.svcCtx)
return l.{{.method}}(in)
}
`
typeFmt = `%sServer struct {
svcCtx *svc.ServiceContext
}`
)
func (g *defaultRpcGenerator) genHandler() error {
serverPath := g.dirM[dirServer]
file := g.ast
logicImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirLogic))
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
imports := collection.NewSet()
imports.AddStr(logicImport, svcImport)
head := util.GetHead(g.Ctx.ProtoSource)
for _, service := range file.Service {
filename := fmt.Sprintf("%vserver.go", service.Name.Lower())
serverFile := filepath.Join(serverPath, filename)
funcList, importList, err := g.genFunctions(service)
if err != nil {
return err
}
imports.AddStr(importList...)
text, err := util.LoadTemplate(category, serverTemplateFile, serverTemplate)
if err != nil {
return err
}
err = util.With("server").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head,
"types": fmt.Sprintf(typeFmt, service.Name.Title()),
"server": service.Name.Title(),
"imports": strings.Join(imports.KeysStr(), util.NL),
"funcs": strings.Join(funcList, util.NL),
}, serverFile, true)
if err != nil {
return err
}
}
return nil
}
func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string, []string, error) {
file := g.ast
pkg := file.Package
var functionList []string
imports := collection.NewSet()
for _, method := range service.Funcs {
if method.ParameterIn.Package == pkg || method.ParameterOut.Package == pkg {
imports.AddStr(fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb)))
}
imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
text, err := util.LoadTemplate(category, serverFuncTemplateFile, functionTemplate)
if err != nil {
return nil, nil, err
}
buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
"server": service.Name.Title(),
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(),
"package": pkg,
"request": method.ParameterIn.StarExpression,
"response": method.ParameterOut.StarExpression,
"hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
})
if err != nil {
return nil, nil, err
}
functionList = append(functionList, buffer.String())
}
return functionList, imports.KeysStr(), nil
}

View File

@@ -1,22 +0,0 @@
package gen
import (
"fmt"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
)
func (g *defaultRpcGenerator) initGoMod() error {
if !g.Ctx.IsInGoEnv {
projectDir := g.dirM[dirTarget]
cmd := fmt.Sprintf("go mod init %s", g.Ctx.ProjectName.Source())
output, err := execx.Run(fmt.Sprintf(cmd), projectDir)
if err != nil {
logx.Error(err)
return err
}
g.Ctx.Info(output)
}
return nil
}

View File

@@ -1,35 +0,0 @@
package base
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
)
func TestParseImport(t *testing.T) {
src, _ := filepath.Abs("./test.proto")
base, _ := filepath.Abs("./base.proto")
imports, containsAny, err := parser.ParseImport(src)
assert.Nil(t, err)
assert.Equal(t, true, containsAny)
assert.Equal(t, 1, len(imports))
assert.Equal(t, "github.com/tal-tech/go-zero/tools/goctl/rpc", imports[0].PbImportName)
assert.Equal(t, base, imports[0].OriginalProtoPath)
}
func TestTransfer(t *testing.T) {
src, _ := filepath.Abs("./test.proto")
abs, _ := filepath.Abs("./test")
imports, _, _ := parser.ParseImport(src)
proto, err := parser.Transfer(src, abs, imports, console.NewConsole(false))
assert.Nil(t, err)
assert.Equal(t, 1, len(proto.Service))
assert.Equal(t, "Greeter", proto.Service[0].Name.Source())
assert.Equal(t, 5, len(proto.Structure))
data, ok := proto.Structure["map"]
assert.Equal(t, true, ok)
assert.Equal(t, "M", data.Field[0].Name.Source())
}

View File

@@ -0,0 +1,75 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: common.proto
package common
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type User struct {
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *User) Reset() { *m = User{} }
func (m *User) String() string { return proto.CompactTextString(m) }
func (*User) ProtoMessage() {}
func (*User) Descriptor() ([]byte, []int) {
return fileDescriptor_555bd8c177793206, []int{0}
}
func (m *User) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_User.Unmarshal(m, b)
}
func (m *User) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_User.Marshal(b, m, deterministic)
}
func (m *User) XXX_Merge(src proto.Message) {
xxx_messageInfo_User.Merge(m, src)
}
func (m *User) XXX_Size() int {
return xxx_messageInfo_User.Size(m)
}
func (m *User) XXX_DiscardUnknown() {
xxx_messageInfo_User.DiscardUnknown(m)
}
var xxx_messageInfo_User proto.InternalMessageInfo
func (m *User) GetName() string {
if m != nil {
return m.Name
}
return ""
}
func init() {
proto.RegisterType((*User)(nil), "common.User")
}
func init() { proto.RegisterFile("common.proto", fileDescriptor_555bd8c177793206) }
var fileDescriptor_555bd8c177793206 = []byte{
// 72 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x49, 0xce, 0xcf, 0xcd,
0xcd, 0xcf, 0xd3, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x83, 0xf0, 0x94, 0xa4, 0xb8, 0x58,
0x42, 0x8b, 0x53, 0x8b, 0x84, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, 0x53, 0x25, 0x18, 0x15, 0x18,
0x35, 0x38, 0x83, 0xc0, 0xec, 0x24, 0x36, 0xb0, 0x52, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff,
0x2c, 0x6d, 0x58, 0x59, 0x3a, 0x00, 0x00, 0x00,
}

View File

@@ -0,0 +1,7 @@
syntax = "proto3";
package common;
message User {
string name = 1;
}

View File

@@ -0,0 +1,33 @@
package generator
import (
"os/exec"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
)
type defaultGenerator struct {
log console.Console
}
func NewDefaultGenerator() *defaultGenerator {
log := console.NewColorConsole()
return &defaultGenerator{
log: log,
}
}
func (g *defaultGenerator) Prepare() error {
_, err := exec.LookPath("go")
if err != nil {
return err
}
_, err = exec.LookPath("protoc")
if err != nil {
return err
}
_, err = exec.LookPath("protoc-gen-go")
return err
}

View File

@@ -0,0 +1,11 @@
package generator
import (
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
func formatFilename(filename string) string {
return strings.ToLower(stringx.From(filename).ToCamel())
}

View File

@@ -0,0 +1,98 @@
package generator
import (
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
type RpcGenerator struct {
g Generator
}
func NewDefaultRpcGenerator() *RpcGenerator {
return NewRpcGenerator(NewDefaultGenerator())
}
func NewRpcGenerator(g Generator) *RpcGenerator {
return &RpcGenerator{
g: g,
}
}
func (g *RpcGenerator) Generate(src, target string, protoImportPath []string) error {
abs, err := filepath.Abs(target)
if err != nil {
return err
}
err = util.MkdirIfNotExist(abs)
if err != nil {
return err
}
err = g.g.Prepare()
if err != nil {
return err
}
projectCtx, err := ctx.Prepare(abs)
if err != nil {
return err
}
p := parser.NewDefaultProtoParser()
proto, err := p.Parse(src)
if err != nil {
return err
}
dirCtx, err := mkdir(projectCtx, proto)
if err != nil {
return err
}
err = g.g.GenEtc(dirCtx, proto)
if err != nil {
return err
}
err = g.g.GenPb(dirCtx, protoImportPath, proto)
if err != nil {
return err
}
err = g.g.GenConfig(dirCtx, proto)
if err != nil {
return err
}
err = g.g.GenSvc(dirCtx, proto)
if err != nil {
return err
}
err = g.g.GenLogic(dirCtx, proto)
if err != nil {
return err
}
err = g.g.GenServer(dirCtx, proto)
if err != nil {
return err
}
err = g.g.GenMain(dirCtx, proto)
if err != nil {
return err
}
err = g.g.GenCall(dirCtx, proto)
console.NewColorConsole().MarkDone()
return err
}

View File

@@ -0,0 +1,128 @@
package generator
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
)
func TestRpcGenerateCaseNilImport(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_stream.proto", abs, nil)
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
}
func TestRpcGenerateCaseOption(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_option.proto", abs, nil)
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
}
func TestRpcGenerateCaseWordOption(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_word_option.proto", abs, nil)
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
}
// test keyword go
func TestRpcGenerateCaseGoOption(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_go_option.proto", abs, nil)
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
}
func TestRpcGenerateCaseImport(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_import.proto", abs, []string{"./base"})
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.True(t, func() bool {
return strings.Contains(err.Error(), "package base is not in GOROOT")
}())
}
}
func TestRpcGenerateCaseServiceRpcNamingSnake(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator()
if err := dispatcher.Prepare(); err == nil {
g := NewRpcGenerator(dispatcher)
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
err = g.Generate("./test_service_rpc_naming_snake.proto", abs, nil)
defer func() {
_ = os.RemoveAll(abs)
}()
assert.Nil(t, err)
_, err = execx.Run("go test "+abs, abs)
assert.Nil(t, err)
}
}

View File

@@ -0,0 +1,156 @@
package generator
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const (
callTemplateText = `{{.head}}
//go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE
package {{.filePackage}}
import (
"context"
{{.package}}
"github.com/tal-tech/go-zero/zrpc"
)
type (
{{.alias}}
{{.serviceName}} interface {
{{.interface}}
}
default{{.serviceName}} struct {
cli zrpc.Client
}
)
func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
return &default{{.serviceName}}{
cli: cli,
}
}
{{.functions}}
`
callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
callFunctionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (m *default{{.serviceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) {
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
return client.{{.method}}(ctx, in)
}
`
)
func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto) error {
dir := ctx.GetCall()
service := proto.Service
head := util.GetHead(proto.Name)
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", formatFilename(service.Name)))
functions, err := g.genFunction(proto.PbPackage, service)
if err != nil {
return err
}
iFunctions, err := g.getInterfaceFuncs(service)
if err != nil {
return err
}
text, err := util.LoadTemplate(category, callTemplateFile, callTemplateText)
if err != nil {
return err
}
var alias = collection.NewSet()
for _, item := range service.RPC {
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(item.RequestType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.RequestType))))
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(item.ReturnsType), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(item.ReturnsType))))
}
err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"name": formatFilename(service.Name),
"alias": strings.Join(alias.KeysStr(), util.NL),
"head": head,
"filePackage": dir.Base,
"package": fmt.Sprintf(`"%s"`, ctx.GetPb().Package),
"serviceName": stringx.From(service.Name).ToCamel(),
"functions": strings.Join(functions, util.NL),
"interface": strings.Join(iFunctions, util.NL),
}, filename, true)
return err
}
func (g *defaultGenerator) genFunction(goPackage string, service parser.Service) ([]string, error) {
functions := make([]string, 0)
for _, rpc := range service.RPC {
text, err := util.LoadTemplate(category, callFunctionTemplateFile, callFunctionTemplate)
if err != nil {
return nil, err
}
comment := parser.GetComment(rpc.Doc())
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
"serviceName": stringx.From(service.Name).ToCamel(),
"rpcServiceName": parser.CamelCase(service.Name),
"method": parser.CamelCase(rpc.Name),
"package": goPackage,
"pbRequest": parser.CamelCase(rpc.RequestType),
"pbResponse": parser.CamelCase(rpc.ReturnsType),
"hasComment": len(comment) > 0,
"comment": comment,
})
if err != nil {
return nil, err
}
functions = append(functions, buffer.String())
}
return functions, nil
}
func (g *defaultGenerator) getInterfaceFuncs(service parser.Service) ([]string, error) {
functions := make([]string, 0)
for _, rpc := range service.RPC {
text, err := util.LoadTemplate(category, callInterfaceFunctionTemplateFile, callInterfaceFunctionTemplate)
if err != nil {
return nil, err
}
comment := parser.GetComment(rpc.Doc())
buffer, err := util.With("interfaceFn").Parse(text).Execute(
map[string]interface{}{
"hasComment": len(comment) > 0,
"comment": comment,
"method": parser.CamelCase(rpc.Name),
"pbRequest": parser.CamelCase(rpc.RequestType),
"pbResponse": parser.CamelCase(rpc.ReturnsType),
})
if err != nil {
return nil, err
}
functions = append(functions, buffer.String())
}
return functions, nil
}

View File

@@ -0,0 +1,44 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateCall(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenCall(dirCtx, proto)
assert.Nil(t, err)
}

View File

@@ -1,10 +1,11 @@
package gen package generator
import ( import (
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
@@ -17,9 +18,9 @@ type Config struct {
} }
` `
func (g *defaultRpcGenerator) genConfig() error { func (g *defaultGenerator) GenConfig(ctx DirContext, _ parser.Proto) error {
configPath := g.dirM[dirConfig] dir := ctx.GetConfig()
fileName := filepath.Join(configPath, fileConfig) fileName := filepath.Join(dir.Filename, formatFilename("config")+".go")
if util.FileExists(fileName) { if util.FileExists(fileName) {
return nil return nil
} }

View File

@@ -0,0 +1,48 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateConfig(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenConfig(dirCtx, proto)
assert.Nil(t, err)
// test file exists
err = g.GenConfig(dirCtx, proto)
assert.Nil(t, err)
}

View File

@@ -0,0 +1,15 @@
package generator
import "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
type Generator interface {
Prepare() error
GenMain(ctx DirContext, proto parser.Proto) error
GenCall(ctx DirContext, proto parser.Proto) error
GenEtc(ctx DirContext, proto parser.Proto) error
GenConfig(ctx DirContext, proto parser.Proto) error
GenLogic(ctx DirContext, proto parser.Proto) error
GenServer(ctx DirContext, proto parser.Proto) error
GenSvc(ctx DirContext, proto parser.Proto) error
GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error
}

View File

@@ -1,9 +1,10 @@
package gen package generator
import ( import (
"fmt" "fmt"
"path/filepath" "path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
@@ -15,12 +16,10 @@ Etcd:
Key: {{.serviceName}}.rpc Key: {{.serviceName}}.rpc
` `
func (g *defaultRpcGenerator) genEtc() error { func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto) error {
etdDir := g.dirM[dirEtc] dir := ctx.GetEtc()
fileName := filepath.Join(etdDir, fmt.Sprintf("%v.yaml", g.Ctx.ServiceName.Lower())) serviceNameLower := formatFilename(ctx.GetMain().Base)
if util.FileExists(fileName) { fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.yaml", serviceNameLower))
return nil
}
text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate) text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate)
if err != nil { if err != nil {
@@ -28,6 +27,6 @@ func (g *defaultRpcGenerator) genEtc() error {
} }
return util.With("etc").Parse(text).SaveTo(map[string]interface{}{ return util.With("etc").Parse(text).SaveTo(map[string]interface{}{
"serviceName": g.Ctx.ServiceName.Lower(), "serviceName": serviceNameLower,
}, fileName, false) }, fileName, false)
} }

View File

@@ -0,0 +1,45 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateEtc(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenEtc(dirCtx, proto)
assert.Nil(t, err)
}

View File

@@ -0,0 +1,101 @@
package generator
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const (
logicTemplate = `package logic
import (
"context"
{{.imports}}
"github.com/tal-tech/go-zero/core/logx"
)
type {{.logicName}} struct {
ctx context.Context
svcCtx *svc.ServiceContext
logx.Logger
}
func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} {
return &{{.logicName}}{
ctx: ctx,
svcCtx: svcCtx,
Logger: logx.WithContext(ctx),
}
}
{{.functions}}
`
logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
// todo: add your logic here and delete this line
return &{{.responseType}}{}, nil
}
`
)
func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto) error {
dir := ctx.GetLogic()
for _, rpc := range proto.Service.RPC {
filename := filepath.Join(dir.Filename, formatFilename(rpc.Name+"_logic")+".go")
functions, err := g.genLogicFunction(proto.PbPackage, rpc)
if err != nil {
return err
}
imports := collection.NewSet()
imports.AddStr(fmt.Sprintf(`"%v"`, ctx.GetSvc().Package))
imports.AddStr(fmt.Sprintf(`"%v"`, ctx.GetPb().Package))
text, err := util.LoadTemplate(category, logicTemplateFileFile, logicTemplate)
if err != nil {
return err
}
err = util.With("logic").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel()),
"functions": functions,
"imports": strings.Join(imports.KeysStr(), util.NL),
}, filename, false)
if err != nil {
return err
}
}
return nil
}
func (g *defaultGenerator) genLogicFunction(goPackage string, rpc *parser.RPC) (string, error) {
var functions = make([]string, 0)
text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
if err != nil {
return "", err
}
logicName := stringx.From(rpc.Name + "_logic").ToCamel()
comment := parser.GetComment(rpc.Doc())
buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{
"logicName": logicName,
"method": parser.CamelCase(rpc.Name),
"request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
"response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
"responseType": fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
"hasComment": len(comment) > 0,
"comment": comment,
})
if err != nil {
return "", err
}
functions = append(functions, buffer.String())
return strings.Join(functions, util.NL), nil
}

View File

@@ -0,0 +1,44 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateLogic(t *testing.T) {
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenLogic(dirCtx, proto)
assert.Nil(t, err)
}

View File

@@ -0,0 +1,72 @@
package generator
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const mainTemplate = `{{.head}}
package main
import (
"flag"
"fmt"
{{.imports}}
"github.com/tal-tech/go-zero/core/conf"
"github.com/tal-tech/go-zero/zrpc"
"google.golang.org/grpc"
)
var configFile = flag.String("f", "etc/{{.serviceName}}.yaml", "the config file")
func main() {
flag.Parse()
var c config.Config
conf.MustLoad(*configFile, &c)
ctx := svc.NewServiceContext(c)
srv := server.New{{.serviceNew}}Server(ctx)
s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) {
{{.pkg}}.Register{{.service}}Server(grpcServer, srv)
})
defer s.Stop()
fmt.Printf("Starting rpc server at %s...\n", c.ListenOn)
s.Start()
}
`
func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto) error {
dir := ctx.GetMain()
serviceNameLower := formatFilename(ctx.GetMain().Base)
fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.go", serviceNameLower))
imports := make([]string, 0)
pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
remoteImport := fmt.Sprintf(`"%v"`, ctx.GetServer().Package)
configImport := fmt.Sprintf(`"%v"`, ctx.GetConfig().Package)
imports = append(imports, configImport, pbImport, remoteImport, svcImport)
head := util.GetHead(proto.Name)
text, err := util.LoadTemplate(category, mainTemplateFile, mainTemplate)
if err != nil {
return err
}
return util.With("main").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"head": head,
"serviceName": serviceNameLower,
"imports": strings.Join(imports, util.NL),
"pkg": proto.PbPackage,
"serviceNew": stringx.From(proto.Service.Name).ToCamel(),
"service": parser.CamelCase(proto.Service.Name),
}, fileName, false)
}

View File

@@ -0,0 +1,45 @@
package generator
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/ctx"
)
func TestGenerateMain(t *testing.T) {
_ = Clean()
project := "stream"
abs, err := filepath.Abs("./test")
assert.Nil(t, err)
dir := filepath.Join(abs, project)
err = util.MkdirIfNotExist(dir)
assert.Nil(t, err)
defer func() {
_ = os.RemoveAll(abs)
}()
projectCtx, err := ctx.Prepare(dir)
assert.Nil(t, err)
p := parser.NewDefaultProtoParser()
proto, err := p.Parse("./test_stream.proto")
assert.Nil(t, err)
dirCtx, err := mkdir(projectCtx, proto)
assert.Nil(t, err)
g := NewDefaultGenerator()
err = g.Prepare()
if err != nil {
return
}
err = g.GenMain(dirCtx, proto)
assert.Nil(t, err)
}

View File

@@ -0,0 +1,31 @@
package generator
import (
"bytes"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
)
func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto) error {
dir := ctx.GetPb()
cw := new(bytes.Buffer)
base := filepath.Dir(proto.Src)
cw.WriteString("protoc ")
for _, ip := range protoImportPath {
cw.WriteString(" -I=" + ip)
}
cw.WriteString(" -I=" + base)
cw.WriteString(" " + proto.Name)
if strings.Contains(proto.GoPackage, "/") {
cw.WriteString(" --go_out=plugins=grpc:" + ctx.GetMain().Filename)
} else {
cw.WriteString(" --go_out=plugins=grpc:" + dir.Filename)
}
command := cw.String()
g.log.Debug(command)
_, err := execx.Run(command, "")
return err
}

Some files were not shown because too many files have changed in this diff Show More