mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-20 13:18:17 +08:00
Compare commits
30 Commits
tools/goct
...
v1.3.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
78ea0769fd | ||
|
|
e0fa8d820d | ||
|
|
dfd58c213c | ||
|
|
83cacf51b7 | ||
|
|
6dccfa29fd | ||
|
|
7e0b0ab0b1 | ||
|
|
ac18cc470d | ||
|
|
f4471846ff | ||
|
|
9c2d526a11 | ||
|
|
2b9fc26c38 | ||
|
|
321dc2d410 | ||
|
|
500bd87c85 | ||
|
|
e9620c8c05 | ||
|
|
70e51bb352 | ||
|
|
278cd123c8 | ||
|
|
3febb1a5d0 | ||
|
|
d8054d8def | ||
|
|
ec271db7a0 | ||
|
|
bbac994c8a | ||
|
|
c1d9e6a00b | ||
|
|
0aeb49a6b0 | ||
|
|
fe262766b4 | ||
|
|
7181505c8a | ||
|
|
f060a226bc | ||
|
|
93d524b797 | ||
|
|
5c169f4f49 | ||
|
|
d29dfa12e3 | ||
|
|
194f55e08e | ||
|
|
c0f9892fe3 | ||
|
|
227104d7d7 |
2
LICENSE
2
LICENSE
@@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2020 xiaoheiban_server_go
|
||||
Copyright (c) 2022 zeromicro
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
@@ -448,7 +448,15 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
|
||||
dereffedBaseType := Deref(baseType)
|
||||
dereffedBaseKind := dereffedBaseType.Kind()
|
||||
refValue := reflect.ValueOf(mapValue)
|
||||
if refValue.IsNil() {
|
||||
return nil
|
||||
}
|
||||
|
||||
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
|
||||
if refValue.Len() == 0 {
|
||||
value.Set(conv)
|
||||
return nil
|
||||
}
|
||||
|
||||
var valid bool
|
||||
for i := 0; i < refValue.Len(); i++ {
|
||||
|
||||
@@ -198,6 +198,49 @@ func TestUnmarshalIntWithDefault(t *testing.T) {
|
||||
assert.Equal(t, 1, in.Int)
|
||||
}
|
||||
|
||||
func TestUnmarshalBoolSliceRequired(t *testing.T) {
|
||||
type inner struct {
|
||||
Bools []bool `key:"bools"`
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.NotNil(t, UnmarshalKey(map[string]interface{}{}, &in))
|
||||
}
|
||||
|
||||
func TestUnmarshalBoolSliceNil(t *testing.T) {
|
||||
type inner struct {
|
||||
Bools []bool `key:"bools,optional"`
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Nil(t, UnmarshalKey(map[string]interface{}{}, &in))
|
||||
assert.Nil(t, in.Bools)
|
||||
}
|
||||
|
||||
func TestUnmarshalBoolSliceNilExplicit(t *testing.T) {
|
||||
type inner struct {
|
||||
Bools []bool `key:"bools,optional"`
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Nil(t, UnmarshalKey(map[string]interface{}{
|
||||
"bools": nil,
|
||||
}, &in))
|
||||
assert.Nil(t, in.Bools)
|
||||
}
|
||||
|
||||
func TestUnmarshalBoolSliceEmpty(t *testing.T) {
|
||||
type inner struct {
|
||||
Bools []bool `key:"bools,optional"`
|
||||
}
|
||||
|
||||
var in inner
|
||||
assert.Nil(t, UnmarshalKey(map[string]interface{}{
|
||||
"bools": []bool{},
|
||||
}, &in))
|
||||
assert.Empty(t, in.Bools)
|
||||
}
|
||||
|
||||
func TestUnmarshalBoolSliceWithDefault(t *testing.T) {
|
||||
type inner struct {
|
||||
Bools []bool `key:"bools,default=[true,false]"`
|
||||
@@ -330,28 +373,34 @@ func TestUnmarshalFloat(t *testing.T) {
|
||||
|
||||
func TestUnmarshalInt64Slice(t *testing.T) {
|
||||
var v struct {
|
||||
Ages []int64 `key:"ages"`
|
||||
Ages []int64 `key:"ages"`
|
||||
Slice []int64 `key:"slice"`
|
||||
}
|
||||
m := map[string]interface{}{
|
||||
"ages": []int64{1, 2},
|
||||
"ages": []int64{1, 2},
|
||||
"slice": []interface{}{},
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
ast.Nil(UnmarshalKey(m, &v))
|
||||
ast.ElementsMatch([]int64{1, 2}, v.Ages)
|
||||
ast.Equal([]int64{}, v.Slice)
|
||||
}
|
||||
|
||||
func TestUnmarshalIntSlice(t *testing.T) {
|
||||
var v struct {
|
||||
Ages []int `key:"ages"`
|
||||
Ages []int `key:"ages"`
|
||||
Slice []int `key:"slice"`
|
||||
}
|
||||
m := map[string]interface{}{
|
||||
"ages": []int{1, 2},
|
||||
"ages": []int{1, 2},
|
||||
"slice": []interface{}{},
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
ast.Nil(UnmarshalKey(m, &v))
|
||||
ast.ElementsMatch([]int{1, 2}, v.Ages)
|
||||
ast.Equal([]int{}, v.Slice)
|
||||
}
|
||||
|
||||
func TestUnmarshalString(t *testing.T) {
|
||||
|
||||
@@ -11,6 +11,7 @@ const (
|
||||
mega = 1024 * 1024
|
||||
)
|
||||
|
||||
// DisplayStats prints the goroutine, memory, GC stats with given interval, default to 5 seconds.
|
||||
func DisplayStats(interval ...time.Duration) {
|
||||
duration := defaultInterval
|
||||
for _, val := range interval {
|
||||
|
||||
@@ -41,6 +41,16 @@ func RawFieldNames(in interface{}, postgresSql ...bool) []string {
|
||||
out = append(out, fmt.Sprintf("`%s`", fi.Name))
|
||||
}
|
||||
default:
|
||||
// get tag name with the tag opton, e.g.:
|
||||
// `db:"id"`
|
||||
// `db:"id,type=char,length=16"`
|
||||
// `db:",type=char,length=16"`
|
||||
if strings.Contains(tagv, ",") {
|
||||
tagv = strings.TrimSpace(strings.Split(tagv, ",")[0])
|
||||
}
|
||||
if len(tagv) == 0 {
|
||||
tagv = fi.Name
|
||||
}
|
||||
if pg {
|
||||
out = append(out, tagv)
|
||||
} else {
|
||||
|
||||
@@ -22,3 +22,20 @@ func TestFieldNames(t *testing.T) {
|
||||
assert.Equal(t, expected, out)
|
||||
})
|
||||
}
|
||||
|
||||
type mockedUserWithOptions struct {
|
||||
ID string `db:"id" json:"id,omitempty"`
|
||||
UserName string `db:"user_name,type=varchar,length=255" json:"userName,omitempty"`
|
||||
Sex int `db:"sex" json:"sex,omitempty"`
|
||||
UUID string `db:",type=varchar,length=16" uuid:"uuid,omitempty"`
|
||||
Age int `db:"age" json:"age"`
|
||||
}
|
||||
|
||||
func TestFieldNamesWithTagOptions(t *testing.T) {
|
||||
t.Run("new", func(t *testing.T) {
|
||||
var u mockedUserWithOptions
|
||||
out := RawFieldNames(&u)
|
||||
expected := []string{"`id`", "`user_name`", "`sex`", "`UUID`", "`age`"}
|
||||
assert.Equal(t, expected, out)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ type (
|
||||
Expire(key string, seconds int) error
|
||||
Expireat(key string, expireTime int64) error
|
||||
Get(key string) (string, error)
|
||||
GetSet(key, value string) (string, error)
|
||||
Hdel(key, field string) (bool, error)
|
||||
Hexists(key, field string) (bool, error)
|
||||
Hget(key, field string) (string, error)
|
||||
@@ -459,6 +460,15 @@ func (cs clusterStore) SetnxEx(key, value string, seconds int) (bool, error) {
|
||||
return node.SetnxEx(key, value, seconds)
|
||||
}
|
||||
|
||||
func (cs clusterStore) GetSet(key, value string) (string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return node.GetSet(key, value)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Sismember(key string, value interface{}) (bool, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
|
||||
@@ -490,6 +490,29 @@ func TestRedis_SetExNx(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Getset(t *testing.T) {
|
||||
store := clusterStore{dispatcher: hash.NewConsistentHash()}
|
||||
_, err := store.GetSet("hello", "world")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
runOnCluster(t, func(client Store) {
|
||||
val, err := client.GetSet("hello", "world")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "", val)
|
||||
val, err = client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "world", val)
|
||||
val, err = client.GetSet("hello", "newworld")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "world", val)
|
||||
val, err = client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "newworld", val)
|
||||
_, err = client.Del("hello")
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_SetGetDelHashField(t *testing.T) {
|
||||
store := clusterStore{dispatcher: hash.NewConsistentHash()}
|
||||
err := store.Hset("key", "field", "value")
|
||||
|
||||
@@ -640,6 +640,29 @@ func (s *Redis) GetBitCtx(ctx context.Context, key string, offset int64) (val in
|
||||
return
|
||||
}
|
||||
|
||||
// GetSet is the implementation of redis getset command.
|
||||
func (s *Redis) GetSet(key, value string) (string, error) {
|
||||
return s.GetSetCtx(context.Background(), key, value)
|
||||
}
|
||||
|
||||
// GetSetCtx is the implementation of redis getset command.
|
||||
func (s *Redis) GetSetCtx(ctx context.Context, key, value string) (val string, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val, err = conn.GetSet(ctx, key, value).Result(); err == red.Nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Hdel is the implementation of redis hdel command.
|
||||
func (s *Redis) Hdel(key string, fields ...string) (bool, error) {
|
||||
return s.HdelCtx(context.Background(), key, fields...)
|
||||
@@ -1381,21 +1404,28 @@ func (s *Redis) ScanCtx(ctx context.Context, cursor uint64, match string, count
|
||||
}
|
||||
|
||||
// SetBit is the implementation of redis setbit command.
|
||||
func (s *Redis) SetBit(key string, offset int64, value int) error {
|
||||
func (s *Redis) SetBit(key string, offset int64, value int) (int, error) {
|
||||
return s.SetBitCtx(context.Background(), key, offset, value)
|
||||
}
|
||||
|
||||
// SetBitCtx is the implementation of redis setbit command.
|
||||
func (s *Redis) SetBitCtx(ctx context.Context, key string, offset int64, value int) error {
|
||||
return s.brk.DoWithAcceptable(func() error {
|
||||
func (s *Redis) SetBitCtx(ctx context.Context, key string, offset int64, value int) (val int, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = conn.SetBit(ctx, key, offset, value).Result()
|
||||
return err
|
||||
v, err := conn.SetBit(ctx, key, offset, value).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val = int(v)
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Sscan is the implementation of redis sscan command.
|
||||
|
||||
@@ -387,30 +387,33 @@ func TestRedis_Mget(t *testing.T) {
|
||||
|
||||
func TestRedis_SetBit(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := New(client.Addr, badType()).SetBit("key", 1, 1)
|
||||
_, err := New(client.Addr, badType()).SetBit("key", 1, 1)
|
||||
assert.NotNil(t, err)
|
||||
err = client.SetBit("key", 1, 1)
|
||||
val, err := client.SetBit("key", 1, 1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_GetBit(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.SetBit("key", 2, 1)
|
||||
val, err := client.SetBit("key", 2, 1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, val)
|
||||
_, err = New(client.Addr, badType()).GetBit("key", 2)
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.GetBit("key", 2)
|
||||
v, err := client.GetBit("key", 2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, val)
|
||||
assert.Equal(t, 1, v)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_BitCount(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
for i := 0; i < 11; i++ {
|
||||
err := client.SetBit("key", int64(i), 1)
|
||||
val, err := client.SetBit("key", int64(i), 1)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, val)
|
||||
}
|
||||
|
||||
_, err := New(client.Addr, badType()).BitCount("key", 0, -1)
|
||||
@@ -701,6 +704,28 @@ func TestRedis_Set(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_GetSet(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := New(client.Addr, badType()).GetSet("hello", "world")
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.GetSet("hello", "world")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "", val)
|
||||
val, err = client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "world", val)
|
||||
val, err = client.GetSet("hello", "newworld")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "world", val)
|
||||
val, err = client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "newworld", val)
|
||||
ret, err := client.Del("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_SetGetDel(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := New(client.Addr, badType()).Set("hello", "world")
|
||||
|
||||
@@ -2,6 +2,7 @@ package redis
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -11,19 +12,26 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
randomLen = 16
|
||||
tolerance = 500 // milliseconds
|
||||
millisPerSecond = 1000
|
||||
lockCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
redis.call("SET", KEYS[1], ARGV[1], "PX", ARGV[2])
|
||||
return "OK"
|
||||
else
|
||||
return redis.call("SET", KEYS[1], ARGV[1], "NX", "PX", ARGV[2])
|
||||
end`
|
||||
delCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
else
|
||||
return 0
|
||||
end`
|
||||
randomLen = 16
|
||||
)
|
||||
|
||||
// A RedisLock is a redis lock.
|
||||
type RedisLock struct {
|
||||
store *Redis
|
||||
seconds uint32
|
||||
count int32
|
||||
key string
|
||||
id string
|
||||
}
|
||||
@@ -43,35 +51,30 @@ func NewRedisLock(store *Redis, key string) *RedisLock {
|
||||
|
||||
// Acquire acquires the lock.
|
||||
func (rl *RedisLock) Acquire() (bool, error) {
|
||||
newCount := atomic.AddInt32(&rl.count, 1)
|
||||
if newCount > 1 {
|
||||
seconds := atomic.LoadUint32(&rl.seconds)
|
||||
resp, err := rl.store.Eval(lockCommand, []string{rl.key}, []string{
|
||||
rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance),
|
||||
})
|
||||
if err == red.Nil {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
logx.Errorf("Error on acquiring lock for %s, %s", rl.key, err.Error())
|
||||
return false, err
|
||||
} else if resp == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
reply, ok := resp.(string)
|
||||
if ok && reply == "OK" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
seconds := atomic.LoadUint32(&rl.seconds)
|
||||
ok, err := rl.store.SetnxEx(rl.key, rl.id, int(seconds+1)) // +1s for tolerance
|
||||
if err == red.Nil {
|
||||
atomic.AddInt32(&rl.count, -1)
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
atomic.AddInt32(&rl.count, -1)
|
||||
logx.Errorf("Error on acquiring lock for %s, %s", rl.key, err.Error())
|
||||
return false, err
|
||||
} else if !ok {
|
||||
atomic.AddInt32(&rl.count, -1)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
logx.Errorf("Unknown reply when acquiring lock for %s: %v", rl.key, resp)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Release releases the lock.
|
||||
func (rl *RedisLock) Release() (bool, error) {
|
||||
newCount := atomic.AddInt32(&rl.count, -1)
|
||||
if newCount > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
resp, err := rl.store.Eval(delCommand, []string{rl.key}, []string{rl.id})
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
||||
@@ -29,25 +29,5 @@ func TestRedisLock(t *testing.T) {
|
||||
endAcquire, err := secondLock.Acquire()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, endAcquire)
|
||||
|
||||
endAcquire, err = secondLock.Acquire()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, endAcquire)
|
||||
|
||||
release, err = secondLock.Release()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, release)
|
||||
|
||||
againAcquire, err = firstLock.Acquire()
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, againAcquire)
|
||||
|
||||
release, err = secondLock.Release()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, release)
|
||||
|
||||
firstAcquire, err = firstLock.Acquire()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, firstAcquire)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -75,6 +75,7 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if j > i+1 {
|
||||
index, err := strconv.Atoi(query[i+1 : j])
|
||||
if err != nil {
|
||||
@@ -85,7 +86,7 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
if index > argIndex {
|
||||
argIndex = index
|
||||
}
|
||||
|
||||
|
||||
index--
|
||||
if index < 0 || numArgs <= index {
|
||||
return "", fmt.Errorf("error: wrong index %d in sql", index)
|
||||
|
||||
@@ -94,7 +94,7 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
|
||||
handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
|
||||
handler.RecoverHandler,
|
||||
handler.MetricHandler(metrics),
|
||||
handler.MaxBytesHandler(ng.conf.MaxBytes),
|
||||
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
|
||||
handler.GunzipHandler,
|
||||
)
|
||||
chain = ng.appendAuthHandler(fr, chain, verifier)
|
||||
@@ -119,6 +119,14 @@ func (ng *engine) bindRoutes(router httpx.Router) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ng *engine) checkedMaxBytes(bytes int64) int64 {
|
||||
if bytes > 0 {
|
||||
return bytes
|
||||
}
|
||||
|
||||
return ng.conf.MaxBytes
|
||||
}
|
||||
|
||||
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
|
||||
if timeout > 0 {
|
||||
return timeout
|
||||
|
||||
@@ -194,6 +194,41 @@ func TestEngine_checkedTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_checkedMaxBytes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
maxBytes int64
|
||||
expect int64
|
||||
}{
|
||||
{
|
||||
name: "not set",
|
||||
expect: 1000,
|
||||
},
|
||||
{
|
||||
name: "less",
|
||||
maxBytes: 500,
|
||||
expect: 500,
|
||||
},
|
||||
{
|
||||
name: "equal",
|
||||
maxBytes: 1000,
|
||||
expect: 1000,
|
||||
},
|
||||
{
|
||||
name: "more",
|
||||
maxBytes: 1500,
|
||||
expect: 1500,
|
||||
},
|
||||
}
|
||||
|
||||
ng := newEngine(RestConf{
|
||||
MaxBytes: 1000,
|
||||
})
|
||||
for _, test := range tests {
|
||||
assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_notFoundHandler(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
|
||||
@@ -4,5 +4,5 @@ import "net/http"
|
||||
|
||||
type (
|
||||
Interceptor func(r *http.Request) (*http.Request, ResponseHandler)
|
||||
ResponseHandler func(*http.Response)
|
||||
ResponseHandler func(resp *http.Response, err error)
|
||||
)
|
||||
|
||||
@@ -10,15 +10,21 @@ import (
|
||||
|
||||
func LogInterceptor(r *http.Request) (*http.Request, ResponseHandler) {
|
||||
start := timex.Now()
|
||||
return r, func(resp *http.Response) {
|
||||
return r, func(resp *http.Response, err error) {
|
||||
duration := timex.Since(start)
|
||||
if err != nil {
|
||||
logger := logx.WithContext(r.Context()).WithDuration(duration)
|
||||
logger.Errorf("[HTTP] %s %s - %v", r.Method, r.URL, err)
|
||||
return
|
||||
}
|
||||
|
||||
var tc propagation.TraceContext
|
||||
ctx := tc.Extract(r.Context(), propagation.HeaderCarrier(resp.Header))
|
||||
logger := logx.WithContext(ctx).WithDuration(duration)
|
||||
if isOkResponse(resp.StatusCode) {
|
||||
logger.Infof("[HTTP] %d - %s %s/%s", resp.StatusCode, r.Method, r.Host, r.RequestURI)
|
||||
logger.Infof("[HTTP] %d - %s %s", resp.StatusCode, r.Method, r.URL)
|
||||
} else {
|
||||
logger.Errorf("[HTTP] %d - %s %s/%s", resp.StatusCode, r.Method, r.Host, r.RequestURI)
|
||||
logger.Errorf("[HTTP] %d - %s %s", resp.StatusCode, r.Method, r.URL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,12 +11,13 @@ import (
|
||||
func TestLogInterceptor(t *testing.T) {
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
}))
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
req, handler := LogInterceptor(req)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
handler(resp, err)
|
||||
assert.Nil(t, err)
|
||||
handler(resp)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -24,11 +25,27 @@ func TestLogInterceptorServerError(t *testing.T) {
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
req, handler := LogInterceptor(req)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
handler(resp, err)
|
||||
assert.Nil(t, err)
|
||||
handler(resp)
|
||||
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestLogInterceptorServerClosed(t *testing.T) {
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
svr.Close()
|
||||
req, handler := LogInterceptor(req)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
handler(resp, err)
|
||||
assert.NotNil(t, err)
|
||||
assert.Nil(t, resp)
|
||||
}
|
||||
|
||||
@@ -1,21 +1,44 @@
|
||||
package httpc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest/httpc/internal"
|
||||
)
|
||||
|
||||
// Do sends an HTTP request to the service assocated with the given key.
|
||||
func Do(key string, r *http.Request) (*http.Response, error) {
|
||||
return NewService(key).Do(r)
|
||||
var interceptors = []internal.Interceptor{
|
||||
internal.LogInterceptor,
|
||||
}
|
||||
|
||||
// Get sends an HTTP GET request to the service assocated with the given key.
|
||||
func Get(key, url string) (*http.Response, error) {
|
||||
return NewService(key).Get(url)
|
||||
// DoRequest sends an HTTP request and returns an HTTP response.
|
||||
func DoRequest(r *http.Request) (*http.Response, error) {
|
||||
return request(r, defaultClient{})
|
||||
}
|
||||
|
||||
// Post sends an HTTP POST request to the service assocated with the given key.
|
||||
func Post(key, url, contentType string, body io.Reader) (*http.Response, error) {
|
||||
return NewService(key).Post(url, contentType, body)
|
||||
type (
|
||||
client interface {
|
||||
do(r *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
defaultClient struct{}
|
||||
)
|
||||
|
||||
func (c defaultClient) do(r *http.Request) (*http.Response, error) {
|
||||
return http.DefaultClient.Do(r)
|
||||
}
|
||||
|
||||
func request(r *http.Request, cli client) (*http.Response, error) {
|
||||
var respHandlers []internal.ResponseHandler
|
||||
for _, interceptor := range interceptors {
|
||||
var h internal.ResponseHandler
|
||||
r, h = interceptor(r)
|
||||
respHandlers = append(respHandlers, h)
|
||||
}
|
||||
|
||||
resp, err := cli.do(r)
|
||||
for i := len(respHandlers) - 1; i >= 0; i-- {
|
||||
respHandlers[i](resp, err)
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
@@ -11,27 +11,31 @@ import (
|
||||
func TestDo(t *testing.T) {
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
}))
|
||||
_, err := Get("foo", "tcp://bad request")
|
||||
assert.NotNil(t, err)
|
||||
resp, err := Get("foo", svr.URL)
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
resp, err := DoRequest(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestDoNotFound(t *testing.T) {
|
||||
svr := httptest.NewServer(http.NotFoundHandler())
|
||||
_, err := Post("foo", "tcp://bad request", "application/json", nil)
|
||||
assert.NotNil(t, err)
|
||||
resp, err := Post("foo", svr.URL, "application/json", nil)
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := DoRequest(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestDoMoved(t *testing.T) {
|
||||
svr := httptest.NewServer(http.RedirectHandler("/foo", http.StatusMovedPermanently))
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
_, err = Do("foo", req)
|
||||
_, err = DoRequest(req)
|
||||
// too many redirects
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
36
rest/httpc/responses.go
Normal file
36
rest/httpc/responses.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package httpc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
"github.com/zeromicro/go-zero/rest/internal/encoding"
|
||||
)
|
||||
|
||||
// Parse parses the response.
|
||||
func Parse(resp *http.Response, val interface{}) error {
|
||||
if err := ParseHeaders(resp, val); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ParseJsonBody(resp, val)
|
||||
}
|
||||
|
||||
// ParseHeaders parses the rsponse headers.
|
||||
func ParseHeaders(resp *http.Response, val interface{}) error {
|
||||
return encoding.ParseHeaders(resp.Header, val)
|
||||
}
|
||||
|
||||
// ParseJsonBody parses the rsponse body, which should be in json content type.
|
||||
func ParseJsonBody(resp *http.Response, val interface{}) error {
|
||||
if withJsonBody(resp) {
|
||||
return mapping.UnmarshalJsonReader(resp.Body, val)
|
||||
}
|
||||
|
||||
return mapping.UnmarshalJsonMap(nil, val)
|
||||
}
|
||||
|
||||
func withJsonBody(r *http.Response) bool {
|
||||
return r.ContentLength > 0 && strings.Contains(r.Header.Get(contentType), applicationJson)
|
||||
}
|
||||
64
rest/httpc/responses_test.go
Normal file
64
rest/httpc/responses_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package httpc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
var val struct {
|
||||
Foo string `header:"foo"`
|
||||
Name string `json:"name"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("foo", "bar")
|
||||
w.Header().Set(contentType, applicationJson)
|
||||
w.Write([]byte(`{"name":"kevin","value":100}`))
|
||||
}))
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
resp, err := DoRequest(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, Parse(resp, &val))
|
||||
assert.Equal(t, "bar", val.Foo)
|
||||
assert.Equal(t, "kevin", val.Name)
|
||||
assert.Equal(t, 100, val.Value)
|
||||
}
|
||||
|
||||
func TestParseHeaderError(t *testing.T) {
|
||||
var val struct {
|
||||
Foo int `header:"foo"`
|
||||
}
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("foo", "bar")
|
||||
w.Header().Set(contentType, applicationJson)
|
||||
}))
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
resp, err := DoRequest(req)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, Parse(resp, &val))
|
||||
}
|
||||
|
||||
func TestParseNoBody(t *testing.T) {
|
||||
var val struct {
|
||||
Foo string `header:"foo"`
|
||||
}
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("foo", "bar")
|
||||
w.Header().Set(contentType, applicationJson)
|
||||
}))
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
resp, err := DoRequest(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Nil(t, Parse(resp, &val))
|
||||
assert.Equal(t, "bar", val.Foo)
|
||||
}
|
||||
@@ -1,33 +1,19 @@
|
||||
package httpc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/rest/httpc/internal"
|
||||
)
|
||||
|
||||
// ContentType means Content-Type.
|
||||
const ContentType = "Content-Type"
|
||||
|
||||
var interceptors = []internal.Interceptor{
|
||||
internal.LogInterceptor,
|
||||
}
|
||||
|
||||
type (
|
||||
// Option is used to customize the *http.Client.
|
||||
Option func(r *http.Request) *http.Request
|
||||
|
||||
// Service represents a remote HTTP service.
|
||||
Service interface {
|
||||
// Do sends an HTTP request to the service.
|
||||
Do(r *http.Request) (*http.Response, error)
|
||||
// Get sends an HTTP GET request to the service.
|
||||
Get(url string) (*http.Response, error)
|
||||
// Post sends an HTTP POST request to the service.
|
||||
Post(url, contentType string, body io.Reader) (*http.Response, error)
|
||||
// DoRequest sends a HTTP request to the service.
|
||||
DoRequest(r *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
namedService struct {
|
||||
@@ -53,50 +39,12 @@ func NewServiceWithClient(name string, cli *http.Client, opts ...Option) Service
|
||||
}
|
||||
}
|
||||
|
||||
// Do sends an HTTP request to the service.
|
||||
func (s namedService) Do(r *http.Request) (resp *http.Response, err error) {
|
||||
var respHandlers []internal.ResponseHandler
|
||||
for _, interceptor := range interceptors {
|
||||
var h internal.ResponseHandler
|
||||
r, h = interceptor(r)
|
||||
respHandlers = append(respHandlers, h)
|
||||
}
|
||||
|
||||
resp, err = s.doRequest(r)
|
||||
if err != nil {
|
||||
logx.Errorf("[HTTP] %s %s/%s - %v", r.Method, r.Host, r.RequestURI, err)
|
||||
return
|
||||
}
|
||||
|
||||
for i := len(respHandlers) - 1; i >= 0; i-- {
|
||||
respHandlers[i](resp)
|
||||
}
|
||||
|
||||
return
|
||||
// DoRequest sends an HTTP request to the service.
|
||||
func (s namedService) DoRequest(r *http.Request) (*http.Response, error) {
|
||||
return request(r, s)
|
||||
}
|
||||
|
||||
// Get sends an HTTP GET request to the service.
|
||||
func (s namedService) Get(url string) (*http.Response, error) {
|
||||
r, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.Do(r)
|
||||
}
|
||||
|
||||
// Post sends an HTTP POST request to the service.
|
||||
func (s namedService) Post(url, contentType string, body io.Reader) (*http.Response, error) {
|
||||
r, err := http.NewRequest(http.MethodPost, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.Header.Set(ContentType, contentType)
|
||||
return s.Do(r)
|
||||
}
|
||||
|
||||
func (s namedService) doRequest(r *http.Request) (resp *http.Response, err error) {
|
||||
func (s namedService) do(r *http.Request) (resp *http.Response, err error) {
|
||||
for _, opt := range s.opts {
|
||||
r = opt(r)
|
||||
}
|
||||
|
||||
@@ -10,10 +10,11 @@ import (
|
||||
|
||||
func TestNamedService_Do(t *testing.T) {
|
||||
svr := httptest.NewServer(http.RedirectHandler("/foo", http.StatusMovedPermanently))
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
service := NewService("foo")
|
||||
_, err = service.Do(req)
|
||||
_, err = service.DoRequest(req)
|
||||
// too many redirects
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
@@ -22,11 +23,14 @@ func TestNamedService_Get(t *testing.T) {
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("foo", r.Header.Get("foo"))
|
||||
}))
|
||||
defer svr.Close()
|
||||
service := NewService("foo", func(r *http.Request) *http.Request {
|
||||
r.Header.Set("foo", "bar")
|
||||
return r
|
||||
})
|
||||
resp, err := service.Get(svr.URL)
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
resp, err := service.DoRequest(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "bar", resp.Header.Get("foo"))
|
||||
@@ -34,10 +38,12 @@ func TestNamedService_Get(t *testing.T) {
|
||||
|
||||
func TestNamedService_Post(t *testing.T) {
|
||||
svr := httptest.NewServer(http.NotFoundHandler())
|
||||
defer svr.Close()
|
||||
service := NewService("foo")
|
||||
_, err := service.Post("tcp://bad request", "application/json", nil)
|
||||
assert.NotNil(t, err)
|
||||
resp, err := service.Post(svr.URL, "application/json", nil)
|
||||
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := service.DoRequest(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
6
rest/httpc/vars.go
Normal file
6
rest/httpc/vars.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package httpc
|
||||
|
||||
const (
|
||||
contentType = "Content-Type"
|
||||
applicationJson = "application/json"
|
||||
)
|
||||
@@ -3,17 +3,16 @@ package httpx
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
"github.com/zeromicro/go-zero/rest/internal/encoding"
|
||||
"github.com/zeromicro/go-zero/rest/pathvar"
|
||||
)
|
||||
|
||||
const (
|
||||
formKey = "form"
|
||||
pathKey = "path"
|
||||
headerKey = "header"
|
||||
maxMemory = 32 << 20 // 32MB
|
||||
maxBodyLen = 8 << 20 // 8MB
|
||||
separator = ";"
|
||||
@@ -21,10 +20,8 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
|
||||
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
|
||||
headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues(),
|
||||
mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey))
|
||||
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
|
||||
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
|
||||
)
|
||||
|
||||
// Parse parses the request.
|
||||
@@ -46,16 +43,7 @@ func Parse(r *http.Request, v interface{}) error {
|
||||
|
||||
// ParseHeaders parses the headers request.
|
||||
func ParseHeaders(r *http.Request, v interface{}) error {
|
||||
m := map[string]interface{}{}
|
||||
for k, v := range r.Header {
|
||||
if len(v) == 1 {
|
||||
m[k] = v[0]
|
||||
} else {
|
||||
m[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return headerUnmarshaler.Unmarshal(m, v)
|
||||
return encoding.ParseHeaders(r.Header, v)
|
||||
}
|
||||
|
||||
// ParseForm parses the form request.
|
||||
|
||||
27
rest/internal/encoding/parser.go
Normal file
27
rest/internal/encoding/parser.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package encoding
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
)
|
||||
|
||||
const headerKey = "header"
|
||||
|
||||
var headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues(),
|
||||
mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey))
|
||||
|
||||
// ParseHeaders parses the headers request.
|
||||
func ParseHeaders(header http.Header, v interface{}) error {
|
||||
m := map[string]interface{}{}
|
||||
for k, v := range header {
|
||||
if len(v) == 1 {
|
||||
m[k] = v[0]
|
||||
} else {
|
||||
m[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return headerUnmarshaler.Unmarshal(m, v)
|
||||
}
|
||||
40
rest/internal/encoding/parser_test.go
Normal file
40
rest/internal/encoding/parser_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package encoding
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseHeaders(t *testing.T) {
|
||||
var val struct {
|
||||
Foo string `header:"foo"`
|
||||
Baz int `header:"baz"`
|
||||
Qux bool `header:"qux,default=true"`
|
||||
}
|
||||
r := httptest.NewRequest(http.MethodGet, "/any", nil)
|
||||
r.Header.Set("foo", "bar")
|
||||
r.Header.Set("baz", "1")
|
||||
assert.Nil(t, ParseHeaders(r.Header, &val))
|
||||
assert.Equal(t, "bar", val.Foo)
|
||||
assert.Equal(t, 1, val.Baz)
|
||||
assert.True(t, val.Qux)
|
||||
}
|
||||
|
||||
func TestParseHeadersMulti(t *testing.T) {
|
||||
var val struct {
|
||||
Foo []string `header:"foo"`
|
||||
Baz int `header:"baz"`
|
||||
Qux bool `header:"qux,default=true"`
|
||||
}
|
||||
r := httptest.NewRequest(http.MethodGet, "/any", nil)
|
||||
r.Header.Set("foo", "bar")
|
||||
r.Header.Add("foo", "bar1")
|
||||
r.Header.Set("baz", "1")
|
||||
assert.Nil(t, ParseHeaders(r.Header, &val))
|
||||
assert.Equal(t, []string{"bar", "bar1"}, val.Foo)
|
||||
assert.Equal(t, 1, val.Baz)
|
||||
assert.True(t, val.Qux)
|
||||
}
|
||||
@@ -137,6 +137,13 @@ func WithJwtTransition(secret, prevSecret string) RouteOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxBytes returns a RouteOption to set maxBytes with the given value.
|
||||
func WithMaxBytes(maxBytes int64) RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
r.maxBytes = maxBytes
|
||||
}
|
||||
}
|
||||
|
||||
// WithMiddlewares adds given middlewares to given routes.
|
||||
func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
|
||||
for i := len(ms) - 1; i >= 0; i-- {
|
||||
|
||||
@@ -95,6 +95,13 @@ Port: 54321
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithMaxBytes(t *testing.T) {
|
||||
const maxBytes = 1000
|
||||
var fr featuredRoutes
|
||||
WithMaxBytes(maxBytes)(&fr)
|
||||
assert.Equal(t, int64(maxBytes), fr.maxBytes)
|
||||
}
|
||||
|
||||
func TestWithMiddleware(t *testing.T) {
|
||||
m := make(map[string]string)
|
||||
rt := router.NewRouter()
|
||||
|
||||
@@ -36,5 +36,6 @@ type (
|
||||
jwt jwtSetting
|
||||
signature signatureSetting
|
||||
routes []Route
|
||||
maxBytes int64
|
||||
}
|
||||
)
|
||||
|
||||
34
tools/goctl/Dockerfile
Normal file
34
tools/goctl/Dockerfile
Normal file
@@ -0,0 +1,34 @@
|
||||
FROM golang:alpine AS builder
|
||||
|
||||
LABEL stage=gobuilder
|
||||
|
||||
ENV CGO_ENABLED 0
|
||||
ENV GOPROXY https://goproxy.cn,direct
|
||||
|
||||
RUN apk update --no-cache && apk add --no-cache tzdata
|
||||
RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@latest
|
||||
RUN go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
ADD go.mod .
|
||||
ADD go.sum .
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN go build -ldflags="-s -w" -o /app/goctl ./goctl.go
|
||||
|
||||
|
||||
FROM alpine
|
||||
|
||||
RUN apk update --no-cache && apk add --no-cache protoc
|
||||
|
||||
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
|
||||
COPY --from=builder /usr/share/zoneinfo/Asia/Shanghai /usr/share/zoneinfo/Asia/Shanghai
|
||||
COPY --from=builder /go/bin/protoc-gen-go /usr/bin/protoc-gen-go
|
||||
COPY --from=builder /go/bin/protoc-gen-go-grpc /usr/bin/protoc-gen-go-grpc
|
||||
ENV TZ Asia/Shanghai
|
||||
|
||||
WORKDIR /app
|
||||
COPY --from=builder /app/goctl /app/goctl
|
||||
|
||||
CMD ["./goctl"]
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"go/format"
|
||||
"go/scanner"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -29,14 +30,14 @@ const (
|
||||
func GoFormatApi(c *cli.Context) error {
|
||||
useStdin := c.Bool("stdin")
|
||||
skipCheckDeclare := c.Bool("declare")
|
||||
dir := c.String("dir")
|
||||
|
||||
var be errorx.BatchError
|
||||
if useStdin {
|
||||
if err := apiFormatByStdin(skipCheckDeclare); err != nil {
|
||||
if err := apiFormatReader(os.Stdin, dir, skipCheckDeclare); err != nil {
|
||||
be.Add(err)
|
||||
}
|
||||
} else {
|
||||
dir := c.String("dir")
|
||||
if len(dir) == 0 {
|
||||
return errors.New("missing -dir")
|
||||
}
|
||||
@@ -65,13 +66,14 @@ func GoFormatApi(c *cli.Context) error {
|
||||
return be.Err()
|
||||
}
|
||||
|
||||
func apiFormatByStdin(skipCheckDeclare bool) error {
|
||||
data, err := ioutil.ReadAll(os.Stdin)
|
||||
// apiFormatReader
|
||||
// filename is needed when there are `import` literals.
|
||||
func apiFormatReader(reader io.Reader, filename string, skipCheckDeclare bool) error {
|
||||
data, err := ioutil.ReadAll(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := apiFormat(string(data), skipCheckDeclare)
|
||||
result, err := apiFormat(string(data), skipCheckDeclare, filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
package format
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
notFormattedStr = `
|
||||
type Request struct {
|
||||
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
|
||||
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
|
||||
}
|
||||
type Response struct {
|
||||
Message string ` + "`" + `json:"message"` + "`" + `
|
||||
@@ -45,3 +51,26 @@ func TestFormat(t *testing.T) {
|
||||
_, err = apiFormat(notFormattedStr, false)
|
||||
assert.Errorf(t, err, " line 7:13 can not found declaration 'Student' in context")
|
||||
}
|
||||
|
||||
func Test_apiFormatReader_issue1721(t *testing.T) {
|
||||
dir, err := os.MkdirTemp("", "goctl-api-format")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
subDir := path.Join(dir, "sub")
|
||||
err = os.MkdirAll(subDir, fs.ModePerm)
|
||||
require.NoError(t, err)
|
||||
|
||||
importedFilename := path.Join(dir, "foo.api")
|
||||
err = ioutil.WriteFile(importedFilename, []byte{}, fs.ModePerm)
|
||||
require.NoError(t, err)
|
||||
|
||||
filename := path.Join(subDir, "bar.api")
|
||||
err = ioutil.WriteFile(filename, []byte(fmt.Sprintf(`import "%s"`, importedFilename)), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
f, err := os.Open(filename)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = apiFormatReader(f, filename, false)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
@@ -23,7 +24,8 @@ const (
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http"{{if .hasTimeout}}
|
||||
"time"{{end}}
|
||||
|
||||
{{.importPackages}}
|
||||
)
|
||||
@@ -34,9 +36,10 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
|
||||
`
|
||||
routesAdditionTemplate = `
|
||||
server.AddRoutes(
|
||||
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}}
|
||||
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}}
|
||||
)
|
||||
`
|
||||
timeoutThreshold = time.Millisecond
|
||||
)
|
||||
|
||||
var mapping = map[string]string{
|
||||
@@ -57,6 +60,7 @@ type (
|
||||
jwtEnabled bool
|
||||
signatureEnabled bool
|
||||
authName string
|
||||
timeout string
|
||||
middlewares []string
|
||||
prefix string
|
||||
jwtTrans string
|
||||
@@ -80,6 +84,7 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error
|
||||
return err
|
||||
}
|
||||
|
||||
var hasTimeout bool
|
||||
gt := template.Must(template.New("groupTemplate").Parse(templateText))
|
||||
for _, g := range groups {
|
||||
var gbuilder strings.Builder
|
||||
@@ -110,6 +115,22 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error
|
||||
rest.WithPrefix("%s"),`, g.prefix)
|
||||
}
|
||||
|
||||
var timeout string
|
||||
if len(g.timeout) > 0 {
|
||||
duration, err := time.ParseDuration(g.timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// why we check this, maybe some users set value 1, it's 1ns, not 1s.
|
||||
if duration < timeoutThreshold {
|
||||
return fmt.Errorf("timeout should not less than 1ms, now %v", duration)
|
||||
}
|
||||
|
||||
timeout = fmt.Sprintf("rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond)
|
||||
hasTimeout = true
|
||||
}
|
||||
|
||||
var routes string
|
||||
if len(g.middlewares) > 0 {
|
||||
gbuilder.WriteString("\n}...,")
|
||||
@@ -130,6 +151,7 @@ rest.WithPrefix("%s"),`, g.prefix)
|
||||
"jwt": jwt,
|
||||
"signature": signature,
|
||||
"prefix": prefix,
|
||||
"timeout": timeout,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -139,8 +161,8 @@ rest.WithPrefix("%s"),`, g.prefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
routeFilename = routeFilename + ".go"
|
||||
|
||||
routeFilename = routeFilename + ".go"
|
||||
filename := path.Join(dir, handlerDir, routeFilename)
|
||||
os.Remove(filename)
|
||||
|
||||
@@ -152,7 +174,8 @@ rest.WithPrefix("%s"),`, g.prefix)
|
||||
category: category,
|
||||
templateFile: routesTemplateFile,
|
||||
builtinTemplate: routesTemplate,
|
||||
data: map[string]string{
|
||||
data: map[string]interface{}{
|
||||
"hasTimeout": hasTimeout,
|
||||
"importPackages": genRouteImports(rootPkg, api),
|
||||
"routesAdditions": strings.TrimSpace(builder.String()),
|
||||
},
|
||||
@@ -171,7 +194,8 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
|
||||
continue
|
||||
}
|
||||
}
|
||||
importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder), pathx.JoinPackages(parentPkg, handlerDir, folder)))
|
||||
importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder),
|
||||
pathx.JoinPackages(parentPkg, handlerDir, folder)))
|
||||
}
|
||||
}
|
||||
imports := importSet.KeysStr()
|
||||
@@ -205,6 +229,8 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
|
||||
})
|
||||
}
|
||||
|
||||
groupedRoutes.timeout = g.GetAnnotation("timeout")
|
||||
|
||||
jwt := g.GetAnnotation("jwt")
|
||||
if len(jwt) > 0 {
|
||||
groupedRoutes.authName = jwt
|
||||
|
||||
@@ -21,7 +21,6 @@ func TestImportRegex(t *testing.T) {
|
||||
{`"../foo/bar.api"`, true},
|
||||
{`"../../foo/bar.api"`, true},
|
||||
|
||||
{`"bar..api"`, false},
|
||||
{`"//bar.api"`, false},
|
||||
{`"/foo/foo_bar.api"`, true},
|
||||
}
|
||||
|
||||
@@ -48,17 +48,15 @@ func Completion(c *cli.Context) error {
|
||||
|
||||
flag := magic
|
||||
err = ioutil.WriteFile(zshF, zsh, os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
if err == nil {
|
||||
flag |= flagZsh
|
||||
}
|
||||
|
||||
flag |= flagZsh
|
||||
err = ioutil.WriteFile(bashF, bash, os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
if err == nil {
|
||||
flag |= flagBash
|
||||
}
|
||||
|
||||
flag |= flagBash
|
||||
buffer.WriteString(aurora.BrightGreen("generation auto completion success!\n").String())
|
||||
buffer.WriteString(aurora.BrightGreen("executes the following script to setting shell:\n").String())
|
||||
switch flag {
|
||||
|
||||
@@ -28,7 +28,7 @@ type Docker struct {
|
||||
GoRelPath string
|
||||
GoFile string
|
||||
ExeFile string
|
||||
Scratch bool
|
||||
BaseImage string
|
||||
HasPort bool
|
||||
Port int
|
||||
Argument string
|
||||
@@ -74,10 +74,10 @@ func DockerCommand(c *cli.Context) (err error) {
|
||||
return fmt.Errorf("file %q not found", goFile)
|
||||
}
|
||||
|
||||
scratch := c.Bool("scratch")
|
||||
base := c.String("base")
|
||||
port := c.Int("port")
|
||||
if _, err := os.Stat(etcDir); os.IsNotExist(err) {
|
||||
return generateDockerfile(goFile, scratch, port, version, timezone)
|
||||
return generateDockerfile(goFile, base, port, version, timezone)
|
||||
}
|
||||
|
||||
cfg, err := findConfig(goFile, etcDir)
|
||||
@@ -85,7 +85,7 @@ func DockerCommand(c *cli.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := generateDockerfile(goFile, scratch, port, version, timezone, "-f", "etc/"+cfg); err != nil {
|
||||
if err := generateDockerfile(goFile, base, port, version, timezone, "-f", "etc/"+cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ func findConfig(file, dir string) (string, error) {
|
||||
return files[0], nil
|
||||
}
|
||||
|
||||
func generateDockerfile(goFile string, scratch bool, port int, version, timezone string, args ...string) error {
|
||||
func generateDockerfile(goFile, base string, port int, version, timezone string, args ...string) error {
|
||||
projPath, err := getFilePath(filepath.Dir(goFile))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -159,7 +159,7 @@ func generateDockerfile(goFile string, scratch bool, port int, version, timezone
|
||||
GoRelPath: projPath,
|
||||
GoFile: goFile,
|
||||
ExeFile: pathx.FileNameWithoutExt(filepath.Base(goFile)),
|
||||
Scratch: scratch,
|
||||
BaseImage: base,
|
||||
HasPort: port > 0,
|
||||
Port: port,
|
||||
Argument: builder.String(),
|
||||
|
||||
@@ -27,9 +27,9 @@ COPY . .
|
||||
{{end}}RUN go build -ldflags="-s -w" -o /app/{{.ExeFile}} {{.GoRelPath}}/{{.GoFile}}
|
||||
|
||||
|
||||
FROM {{if .Scratch}}scratch{{else}}alpine{{end}}
|
||||
FROM {{.BaseImage}}
|
||||
|
||||
{{if .Scratch}}COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt{{else}}RUN apk update --no-cache && apk add --no-cache ca-certificates{{end}}
|
||||
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
|
||||
{{if .HasTimezone}}COPY --from=builder /usr/share/zoneinfo/{{.Timezone}} /usr/share/zoneinfo/{{.Timezone}}
|
||||
ENV TZ {{.Timezone}}
|
||||
{{end}}
|
||||
|
||||
34
tools/goctl/env/check.go
vendored
34
tools/goctl/env/check.go
vendored
@@ -40,21 +40,23 @@ var bins = []bin{
|
||||
func Check(ctx *cli.Context) error {
|
||||
install := ctx.Bool("install")
|
||||
force := ctx.Bool("force")
|
||||
return Prepare(install, force)
|
||||
verbose := ctx.Bool("verbose")
|
||||
return Prepare(install, force, verbose)
|
||||
}
|
||||
|
||||
func Prepare(install, force bool) error {
|
||||
func Prepare(install, force, verbose bool) error {
|
||||
log := console.NewColorConsole(verbose)
|
||||
pending := true
|
||||
console.Info("[goctl-env]: preparing to check env")
|
||||
log.Info("[goctl-env]: preparing to check env")
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
console.Error("%+v", p)
|
||||
log.Error("%+v", p)
|
||||
return
|
||||
}
|
||||
if pending {
|
||||
console.Success("\n[goctl-env]: congratulations! your goctl environment is ready!")
|
||||
log.Success("\n[goctl-env]: congratulations! your goctl environment is ready!")
|
||||
} else {
|
||||
console.Error(`
|
||||
log.Error(`
|
||||
[goctl-env]: check env finish, some dependencies is not found in PATH, you can execute
|
||||
command 'goctl env check --install' to install it, for details, please execute command
|
||||
'goctl env check --help'`)
|
||||
@@ -62,29 +64,29 @@ command 'goctl env check --install' to install it, for details, please execute c
|
||||
}()
|
||||
for _, e := range bins {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
console.Info("")
|
||||
console.Info("[goctl-env]: looking up %q", e.name)
|
||||
log.Info("")
|
||||
log.Info("[goctl-env]: looking up %q", e.name)
|
||||
if e.exists {
|
||||
console.Success("[goctl-env]: %q is installed", e.name)
|
||||
log.Success("[goctl-env]: %q is installed", e.name)
|
||||
continue
|
||||
}
|
||||
console.Warning("[goctl-env]: %q is not found in PATH", e.name)
|
||||
log.Warning("[goctl-env]: %q is not found in PATH", e.name)
|
||||
if install {
|
||||
install := func() {
|
||||
console.Info("[goctl-env]: preparing to install %q", e.name)
|
||||
log.Info("[goctl-env]: preparing to install %q", e.name)
|
||||
path, err := e.get(env.Get(env.GoctlCache))
|
||||
if err != nil {
|
||||
console.Error("[goctl-env]: an error interrupted the installation: %+v", err)
|
||||
log.Error("[goctl-env]: an error interrupted the installation: %+v", err)
|
||||
pending = false
|
||||
} else {
|
||||
console.Success("[goctl-env]: %q is already installed in %q", e.name, path)
|
||||
log.Success("[goctl-env]: %q is already installed in %q", e.name, path)
|
||||
}
|
||||
}
|
||||
if force {
|
||||
install()
|
||||
continue
|
||||
}
|
||||
console.Info("[goctl-env]: do you want to install %q [y: YES, n: No]", e.name)
|
||||
log.Info("[goctl-env]: do you want to install %q [y: YES, n: No]", e.name)
|
||||
for {
|
||||
var in string
|
||||
fmt.Scanln(&in)
|
||||
@@ -95,10 +97,10 @@ command 'goctl env check --install' to install it, for details, please execute c
|
||||
brk = true
|
||||
case strings.EqualFold(in, "n"):
|
||||
pending = false
|
||||
console.Info("[goctl-env]: %q installation is ignored", e.name)
|
||||
log.Info("[goctl-env]: %q installation is ignored", e.name)
|
||||
brk = true
|
||||
default:
|
||||
console.Error("[goctl-env]: invalid input, input 'y' for yes, 'n' for no")
|
||||
log.Error("[goctl-env]: invalid input, input 'y' for yes, 'n' for no")
|
||||
}
|
||||
if brk {
|
||||
break
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/completion"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/docker"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/env"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/internal/errorx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/kube"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/migrate"
|
||||
@@ -70,6 +69,10 @@ var commands = []cli.Command{
|
||||
Name: "force, f",
|
||||
Usage: "silent installation of non-existent dependencies",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "verbose, v",
|
||||
Usage: "enable log output",
|
||||
},
|
||||
},
|
||||
Action: env.Check,
|
||||
},
|
||||
@@ -345,9 +348,10 @@ var commands = []cli.Command{
|
||||
Name: "go",
|
||||
Usage: "the file that contains main function",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "scratch",
|
||||
Usage: "use scratch for the base docker image",
|
||||
cli.StringFlag{
|
||||
Name: "base",
|
||||
Usage: "the base image to build the docker image, default scratch",
|
||||
Value: "scratch",
|
||||
},
|
||||
cli.IntFlag{
|
||||
Name: "port",
|
||||
@@ -492,9 +496,8 @@ var commands = []cli.Command{
|
||||
Usage: "generate rpc code",
|
||||
Subcommands: []cli.Command{
|
||||
{
|
||||
Name: "new",
|
||||
Usage: `generate rpc demo service`,
|
||||
Description: aurora.Yellow(`deprecated: zrpc code generation use "goctl rpc protoc" instead, for the details see "goctl rpc protoc --help"`).String(),
|
||||
Name: "new",
|
||||
Usage: `generate rpc demo service`,
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "style",
|
||||
@@ -519,6 +522,10 @@ var commands = []cli.Command{
|
||||
Name: "branch",
|
||||
Usage: "the branch of the remote repo, it does work with --remote",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "verbose, v",
|
||||
Usage: "enable log output",
|
||||
},
|
||||
},
|
||||
Action: rpc.RPCNew,
|
||||
},
|
||||
@@ -601,54 +608,11 @@ var commands = []cli.Command{
|
||||
Name: "branch",
|
||||
Usage: "the branch of the remote repo, it does work with --remote",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "proto",
|
||||
Usage: `generate rpc from proto`,
|
||||
Description: aurora.Yellow(`deprecated: zrpc code generation use "goctl rpc protoc" instead, for the details see "goctl rpc protoc --help"`).String(),
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "src, s",
|
||||
Usage: "the file path of the proto source file",
|
||||
},
|
||||
cli.StringSliceFlag{
|
||||
Name: "proto_path, I",
|
||||
Usage: `native command of protoc, specify the directory in which to search for imports. [optional]`,
|
||||
},
|
||||
cli.StringSliceFlag{
|
||||
Name: "go_opt",
|
||||
Usage: `native command of protoc-gen-go, specify the mapping from proto to go, eg --go_opt=proto_import=go_package_import. [optional]`,
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "dir, d",
|
||||
Usage: `the target path of the code`,
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "style",
|
||||
Usage: "the file naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "idea",
|
||||
Usage: "whether the command execution environment is from idea plugin. [optional]",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "home",
|
||||
Usage: "the goctl home path of the template, --home and --remote cannot be set at the same time, " +
|
||||
"if they are, --remote has higher priority",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "remote",
|
||||
Usage: "the remote git repo of the template, --home and --remote cannot be set at the same time, " +
|
||||
"if they are, --remote has higher priority\n\tThe git repo directory must be consistent with the " +
|
||||
"https://github.com/zeromicro/go-zero-template directory structure",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "branch",
|
||||
Usage: "the branch of the remote repo, it does work with --remote",
|
||||
Name: "verbose, v",
|
||||
Usage: "enable log output",
|
||||
},
|
||||
},
|
||||
Action: rpc.RPC,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -941,7 +905,7 @@ func main() {
|
||||
|
||||
// cli already print error messages.
|
||||
if err := app.Run(os.Args); err != nil {
|
||||
fmt.Println(aurora.Red(errorx.Wrap(err).Error()))
|
||||
fmt.Println(aurora.Red(err.Error()))
|
||||
os.Exit(codeFailure)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// BuildVersion is the version of goctl.
|
||||
const BuildVersion = "1.3.3"
|
||||
const BuildVersion = "1.3.4"
|
||||
|
||||
var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-bata": 2, "beta": 3, "released": 4, "": 5}
|
||||
|
||||
|
||||
@@ -20,16 +20,13 @@ import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
pwd = "."
|
||||
createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case
|
||||
)
|
||||
const pwd = "."
|
||||
|
||||
type (
|
||||
defaultGenerator struct {
|
||||
// source string
|
||||
dir string
|
||||
console.Console
|
||||
// source string
|
||||
dir string
|
||||
pkg string
|
||||
cfg *config.Config
|
||||
isPostgreSql bool
|
||||
@@ -48,6 +45,12 @@ type (
|
||||
updateCode string
|
||||
deleteCode string
|
||||
cacheExtra string
|
||||
tableName string
|
||||
}
|
||||
|
||||
codeTuple struct {
|
||||
modelCode string
|
||||
modelCustomCode string
|
||||
}
|
||||
)
|
||||
|
||||
@@ -109,7 +112,7 @@ func (g *defaultGenerator) StartFromDDL(filename string, withCache bool, databas
|
||||
}
|
||||
|
||||
func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.Table, withCache bool) error {
|
||||
m := make(map[string]string)
|
||||
m := make(map[string]*codeTuple)
|
||||
for _, each := range tables {
|
||||
table, err := parser.ConvertDataType(each)
|
||||
if err != nil {
|
||||
@@ -120,14 +123,21 @@ func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.T
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
customCode, err := g.genModelCustom(*table, withCache)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m[table.Name.Source()] = code
|
||||
m[table.Name.Source()] = &codeTuple{
|
||||
modelCode: code,
|
||||
modelCustomCode: customCode,
|
||||
}
|
||||
}
|
||||
|
||||
return g.createFile(m)
|
||||
}
|
||||
|
||||
func (g *defaultGenerator) createFile(modelList map[string]string) error {
|
||||
func (g *defaultGenerator) createFile(modelList map[string]*codeTuple) error {
|
||||
dirAbs, err := filepath.Abs(g.dir)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -140,20 +150,28 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
for tableName, code := range modelList {
|
||||
for tableName, codes := range modelList {
|
||||
tn := stringx.From(tableName)
|
||||
modelFilename, err := format.FileNamingFormat(g.cfg.NamingFormat, fmt.Sprintf("%s_model", tn.Source()))
|
||||
modelFilename, err := format.FileNamingFormat(g.cfg.NamingFormat,
|
||||
fmt.Sprintf("%s_model", tn.Source()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name := util.SafeString(modelFilename) + ".go"
|
||||
name := util.SafeString(modelFilename) + "_gen.go"
|
||||
filename := filepath.Join(dirAbs, name)
|
||||
err = ioutil.WriteFile(filename, []byte(codes.modelCode), os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name = util.SafeString(modelFilename) + ".go"
|
||||
filename = filepath.Join(dirAbs, name)
|
||||
if pathx.FileExists(filename) {
|
||||
g.Warning("%s already exists, ignored.", name)
|
||||
continue
|
||||
}
|
||||
err = ioutil.WriteFile(filename, []byte(code), os.ModePerm)
|
||||
err = ioutil.WriteFile(filename, []byte(codes.modelCustomCode), os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -183,8 +201,9 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error {
|
||||
}
|
||||
|
||||
// ret1: key-table name,value-code
|
||||
func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) (map[string]string, error) {
|
||||
m := make(map[string]string)
|
||||
func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) (
|
||||
map[string]*codeTuple, error) {
|
||||
m := make(map[string]*codeTuple)
|
||||
tables, err := parser.Parse(filename, database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -195,8 +214,15 @@ func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
customCode, err := g.genModelCustom(*e, withCache)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m[e.Name.Source()] = code
|
||||
m[e.Name.Source()] = &codeTuple{
|
||||
modelCode: code,
|
||||
modelCustomCode: customCode,
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
@@ -223,7 +249,7 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
|
||||
table.UniqueCacheKey = uniqueKey
|
||||
table.ContainsUniqueCacheKey = len(uniqueKey) > 0
|
||||
|
||||
importsCode, err := genImports(withCache, in.ContainsTime(), table)
|
||||
importsCode, err := genImports(table, withCache, in.ContainsTime())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -261,7 +287,8 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
|
||||
}
|
||||
|
||||
var list []string
|
||||
list = append(list, insertCodeMethod, findOneCodeMethod, ret.findOneInterfaceMethod, updateCodeMethod, deleteCodeMethod)
|
||||
list = append(list, insertCodeMethod, findOneCodeMethod, ret.findOneInterfaceMethod,
|
||||
updateCodeMethod, deleteCodeMethod)
|
||||
typesCode, err := genTypes(table, strings.Join(modelutil.TrimStringSlice(list), pathx.NL), withCache)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -272,6 +299,11 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
|
||||
return "", err
|
||||
}
|
||||
|
||||
tableName, err := genTableName(table)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
code := &code{
|
||||
importsCode: importsCode,
|
||||
varsCode: varsCode,
|
||||
@@ -282,6 +314,7 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
|
||||
updateCode: updateCode,
|
||||
deleteCode: deleteCode,
|
||||
cacheExtra: ret.cacheExtra,
|
||||
tableName: tableName,
|
||||
}
|
||||
|
||||
output, err := g.executeModel(table, code)
|
||||
@@ -292,8 +325,30 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
|
||||
return output.String(), nil
|
||||
}
|
||||
|
||||
func (g *defaultGenerator) genModelCustom(in parser.Table, withCache bool) (string, error) {
|
||||
text, err := pathx.LoadTemplate(category, modelCustomTemplateFile, template.ModelCustom)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
t := util.With("model-custom").
|
||||
Parse(text).
|
||||
GoFmt(true)
|
||||
output, err := t.Execute(map[string]interface{}{
|
||||
"pkg": g.pkg,
|
||||
"withCache": withCache,
|
||||
"upperStartCamelObject": in.Name.ToCamel(),
|
||||
"lowerStartCamelObject": stringx.From(in.Name.ToCamel()).Untitle(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return output.String(), nil
|
||||
}
|
||||
|
||||
func (g *defaultGenerator) executeModel(table Table, code *code) (*bytes.Buffer, error) {
|
||||
text, err := pathx.LoadTemplate(category, modelTemplateFile, template.Model)
|
||||
text, err := pathx.LoadTemplate(category, modelGenTemplateFile, template.ModelGen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -311,6 +366,7 @@ func (g *defaultGenerator) executeModel(table Table, code *code) (*bytes.Buffer,
|
||||
"update": code.updateCode,
|
||||
"delete": code.deleteCode,
|
||||
"extraMethod": code.cacheExtra,
|
||||
"tableName": code.tableName,
|
||||
"data": table,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -4,16 +4,19 @@ import (
|
||||
"database/sql"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/model/sql/builderx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/model/sql/parser"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
)
|
||||
|
||||
@@ -121,3 +124,31 @@ func TestFields(t *testing.T) {
|
||||
assert.Equal(t, "`name`,`age`,`score`", studentRowsExpectAutoSet)
|
||||
assert.Equal(t, "`name`=?,`age`=?,`score`=?", studentRowsWithPlaceHolder)
|
||||
}
|
||||
|
||||
func Test_genPublicModel(t *testing.T) {
|
||||
var err error
|
||||
dir := pathx.MustTempDir()
|
||||
modelDir := path.Join(dir, "model")
|
||||
err = os.MkdirAll(modelDir, 0777)
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
modelFilename := filepath.Join(modelDir, "foo.sql")
|
||||
err = ioutil.WriteFile(modelFilename, []byte(source), 0777)
|
||||
require.NoError(t, err)
|
||||
|
||||
g, err := NewDefaultGenerator(modelDir, &config.Config{
|
||||
NamingFormat: config.DefaultFormat,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tables, err := parser.Parse(modelFilename, "")
|
||||
require.Equal(t, 1, len(tables))
|
||||
|
||||
code, err := g.genModelCustom(*tables[0], false)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, strings.Contains(code, "package model"))
|
||||
assert.True(t, strings.Contains(code, "TestUserModel interface {\n\t\ttestUserModel\n\t}\n"))
|
||||
assert.True(t, strings.Contains(code, "customTestUserModel struct {\n\t\t*defaultTestUserModel\n\t}\n"))
|
||||
assert.True(t, strings.Contains(code, "func NewTestUserModel(conn sqlx.SqlConn) TestUserModel {"))
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
)
|
||||
|
||||
func genImports(withCache, timeImport bool, table Table) (string, error) {
|
||||
func genImports(table Table, withCache, timeImport bool) (string, error) {
|
||||
if withCache {
|
||||
text, err := pathx.LoadTemplate(category, importsTemplateFile, template.Imports)
|
||||
if err != nil {
|
||||
|
||||
@@ -55,7 +55,6 @@ func genInsert(table Table, withCache, postgreSql bool) (string, string, error)
|
||||
Parse(text).
|
||||
Execute(map[string]interface{}{
|
||||
"withCache": withCache,
|
||||
"containsIndexCache": table.ContainsUniqueCacheKey,
|
||||
"upperStartCamelObject": camel,
|
||||
"lowerStartCamelObject": stringx.From(camel).Untitle(),
|
||||
"expression": strings.Join(expressions, ", "),
|
||||
|
||||
26
tools/goctl/model/sql/gen/tablename.go
Normal file
26
tools/goctl/model/sql/gen/tablename.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package gen
|
||||
|
||||
import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/model/sql/template"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
)
|
||||
|
||||
func genTableName(table Table) (string, error) {
|
||||
text, err := pathx.LoadTemplate(category, tableNameTemplateFile, template.TableName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
output, err := util.With("tableName").
|
||||
Parse(text).
|
||||
Execute(map[string]interface{}{
|
||||
"tableName": table.Name.Source(),
|
||||
"upperStartCamelObject": table.Name.ToCamel(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return output.String(), nil
|
||||
}
|
||||
@@ -22,8 +22,10 @@ const (
|
||||
importsWithNoCacheTemplateFile = "import-no-cache.tpl"
|
||||
insertTemplateFile = "insert.tpl"
|
||||
insertTemplateMethodFile = "interface-insert.tpl"
|
||||
modelTemplateFile = "model.tpl"
|
||||
modelGenTemplateFile = "model-gen.tpl"
|
||||
modelCustomTemplateFile = "model.tpl"
|
||||
modelNewTemplateFile = "model-new.tpl"
|
||||
tableNameTemplateFile = "table-name.tpl"
|
||||
tagTemplateFile = "tag.tpl"
|
||||
typesTemplateFile = "types.tpl"
|
||||
updateTemplateFile = "update.tpl"
|
||||
@@ -45,8 +47,10 @@ var templates = map[string]string{
|
||||
importsWithNoCacheTemplateFile: template.ImportsNoCache,
|
||||
insertTemplateFile: template.Insert,
|
||||
insertTemplateMethodFile: template.InsertMethod,
|
||||
modelTemplateFile: template.Model,
|
||||
modelGenTemplateFile: template.ModelGen,
|
||||
modelCustomTemplateFile: template.ModelCustom,
|
||||
modelNewTemplateFile: template.New,
|
||||
tableNameTemplateFile: template.TableName,
|
||||
tagTemplateFile: template.Tag,
|
||||
typesTemplateFile: template.Types,
|
||||
updateTemplateFile: template.Update,
|
||||
@@ -70,7 +74,7 @@ func GenTemplates(_ *cli.Context) error {
|
||||
return pathx.InitTemplates(category, templates)
|
||||
}
|
||||
|
||||
// RevertTemplate recovers the delete template files
|
||||
// RevertTemplate reverts the deleted template files
|
||||
func RevertTemplate(name string) error {
|
||||
content, ok := templates[name]
|
||||
if !ok {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/model/sql/template"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/stringx"
|
||||
)
|
||||
|
||||
func genTypes(table Table, methods string, withCache bool) (string, error) {
|
||||
@@ -24,6 +25,7 @@ func genTypes(table Table, methods string, withCache bool) (string, error) {
|
||||
"withCache": withCache,
|
||||
"method": methods,
|
||||
"upperStartCamelObject": table.Name.ToCamel(),
|
||||
"lowerStartCamelObject": stringx.From(table.Name.ToCamel()).Untitle(),
|
||||
"fields": fieldsString,
|
||||
"data": table,
|
||||
})
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package template
|
||||
|
||||
// Delete defines a delete template
|
||||
var Delete = `
|
||||
const (
|
||||
// Delete defines a delete template
|
||||
Delete = `
|
||||
func (m *default{{.upperStartCamelObject}}Model) Delete(ctx context.Context, {{.lowerStartCamelPrimaryKey}} {{.dataType}}) error {
|
||||
{{if .withCache}}{{if .containsIndexCache}}data, err:=m.FindOne(ctx, {{.lowerStartCamelPrimaryKey}})
|
||||
if err!=nil{
|
||||
@@ -18,5 +19,6 @@ func (m *default{{.upperStartCamelObject}}Model) Delete(ctx context.Context, {{.
|
||||
}
|
||||
`
|
||||
|
||||
// DeleteMethod defines a delete template for interface method
|
||||
var DeleteMethod = `Delete(ctx context.Context, {{.lowerStartCamelPrimaryKey}} {{.dataType}}) error`
|
||||
// DeleteMethod defines a delete template for interface method
|
||||
DeleteMethod = `Delete(ctx context.Context, {{.lowerStartCamelPrimaryKey}} {{.dataType}}) error`
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package template
|
||||
|
||||
// Error defines an error template
|
||||
var Error = `package {{.pkg}}
|
||||
const Error = `package {{.pkg}}
|
||||
|
||||
import "github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package template
|
||||
|
||||
// Field defines a filed template for types
|
||||
var Field = `{{.name}} {{.type}} {{.tag}} {{if .hasComment}}// {{.comment}}{{end}}`
|
||||
const Field = `{{.name}} {{.type}} {{.tag}} {{if .hasComment}}// {{.comment}}{{end}}`
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package template
|
||||
|
||||
// FindOne defines find row by id.
|
||||
var FindOne = `
|
||||
const (
|
||||
// FindOne defines find row by id.
|
||||
FindOne = `
|
||||
func (m *default{{.upperStartCamelObject}}Model) FindOne(ctx context.Context, {{.lowerStartCamelPrimaryKey}} {{.dataType}}) (*{{.upperStartCamelObject}}, error) {
|
||||
{{if .withCache}}{{.cacheKey}}
|
||||
var resp {{.upperStartCamelObject}}
|
||||
@@ -30,8 +31,8 @@ func (m *default{{.upperStartCamelObject}}Model) FindOne(ctx context.Context, {{
|
||||
}
|
||||
`
|
||||
|
||||
// FindOneByField defines find row by field.
|
||||
var FindOneByField = `
|
||||
// FindOneByField defines find row by field.
|
||||
FindOneByField = `
|
||||
func (m *default{{.upperStartCamelObject}}Model) FindOneBy{{.upperField}}(ctx context.Context, {{.in}}) (*{{.upperStartCamelObject}}, error) {
|
||||
{{if .withCache}}{{.cacheKey}}
|
||||
var resp {{.upperStartCamelObject}}
|
||||
@@ -64,8 +65,8 @@ func (m *default{{.upperStartCamelObject}}Model) FindOneBy{{.upperField}}(ctx co
|
||||
}{{end}}
|
||||
`
|
||||
|
||||
// FindOneByFieldExtraMethod defines find row by field with extras.
|
||||
var FindOneByFieldExtraMethod = `
|
||||
// FindOneByFieldExtraMethod defines find row by field with extras.
|
||||
FindOneByFieldExtraMethod = `
|
||||
func (m *default{{.upperStartCamelObject}}Model) formatPrimary(primary interface{}) string {
|
||||
return fmt.Sprintf("%s%v", {{.primaryKeyLeft}}, primary)
|
||||
}
|
||||
@@ -76,8 +77,9 @@ func (m *default{{.upperStartCamelObject}}Model) queryPrimary(ctx context.Contex
|
||||
}
|
||||
`
|
||||
|
||||
// FindOneMethod defines find row method.
|
||||
var FindOneMethod = `FindOne(ctx context.Context, {{.lowerStartCamelPrimaryKey}} {{.dataType}}) (*{{.upperStartCamelObject}}, error)`
|
||||
// FindOneMethod defines find row method.
|
||||
FindOneMethod = `FindOne(ctx context.Context, {{.lowerStartCamelPrimaryKey}} {{.dataType}}) (*{{.upperStartCamelObject}}, error)`
|
||||
|
||||
// FindOneByFieldMethod defines find row by field method.
|
||||
var FindOneByFieldMethod = `FindOneBy{{.upperField}}(ctx context.Context, {{.in}}) (*{{.upperStartCamelObject}}, error) `
|
||||
// FindOneByFieldMethod defines find row by field method.
|
||||
FindOneByFieldMethod = `FindOneBy{{.upperField}}(ctx context.Context, {{.in}}) (*{{.upperStartCamelObject}}, error) `
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package template
|
||||
|
||||
var (
|
||||
const (
|
||||
// Imports defines a import template for model in cache case
|
||||
Imports = `import (
|
||||
"context"
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
package template
|
||||
|
||||
// Insert defines a template for insert code in model
|
||||
var Insert = `
|
||||
const (
|
||||
// Insert defines a template for insert code in model
|
||||
Insert = `
|
||||
func (m *default{{.upperStartCamelObject}}Model) Insert(ctx context.Context, data *{{.upperStartCamelObject}}) (sql.Result,error) {
|
||||
{{if .withCache}}{{if .containsIndexCache}}{{.keys}}
|
||||
{{if .withCache}}{{.keys}}
|
||||
ret, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||
query := fmt.Sprintf("insert into %s (%s) values ({{.expression}})", m.table, {{.lowerStartCamelObject}}RowsExpectAutoSet)
|
||||
return conn.ExecCtx(ctx, query, {{.expressionValues}})
|
||||
}, {{.keyValues}}){{else}}query := fmt.Sprintf("insert into %s (%s) values ({{.expression}})", m.table, {{.lowerStartCamelObject}}RowsExpectAutoSet)
|
||||
ret,err:=m.ExecNoCacheCtx(ctx, query, {{.expressionValues}})
|
||||
{{end}}{{else}}query := fmt.Sprintf("insert into %s (%s) values ({{.expression}})", m.table, {{.lowerStartCamelObject}}RowsExpectAutoSet)
|
||||
ret,err:=m.conn.ExecCtx(ctx, query, {{.expressionValues}}){{end}}
|
||||
return ret,err
|
||||
}
|
||||
`
|
||||
|
||||
// InsertMethod defines an interface method template for insert code in model
|
||||
var InsertMethod = `Insert(ctx context.Context, data *{{.upperStartCamelObject}}) (sql.Result,error)`
|
||||
// InsertMethod defines an interface method template for insert code in model
|
||||
InsertMethod = `Insert(ctx context.Context, data *{{.upperStartCamelObject}}) (sql.Result,error)`
|
||||
)
|
||||
|
||||
@@ -1,7 +1,47 @@
|
||||
package template
|
||||
|
||||
// Model defines a template for model
|
||||
var Model = `package {{.pkg}}
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||
)
|
||||
|
||||
// ModelCustom defines a template for extension
|
||||
const ModelCustom = `package {{.pkg}}
|
||||
{{if .withCache}}
|
||||
import (
|
||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||
)
|
||||
{{else}}
|
||||
import "github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||
{{end}}
|
||||
var _ {{.upperStartCamelObject}}Model = (*custom{{.upperStartCamelObject}}Model)(nil)
|
||||
|
||||
type (
|
||||
// {{.upperStartCamelObject}}Model is an interface to be customized, add more methods here,
|
||||
// and implement the added methods in custom{{.upperStartCamelObject}}Model.
|
||||
{{.upperStartCamelObject}}Model interface {
|
||||
{{.lowerStartCamelObject}}Model
|
||||
}
|
||||
|
||||
custom{{.upperStartCamelObject}}Model struct {
|
||||
*default{{.upperStartCamelObject}}Model
|
||||
}
|
||||
)
|
||||
|
||||
// New{{.upperStartCamelObject}}Model returns a model for the database table.
|
||||
func New{{.upperStartCamelObject}}Model(conn sqlx.SqlConn{{if .withCache}}, c cache.CacheConf{{end}}) {{.upperStartCamelObject}}Model {
|
||||
return &custom{{.upperStartCamelObject}}Model{
|
||||
default{{.upperStartCamelObject}}Model: new{{.upperStartCamelObject}}Model(conn{{if .withCache}}, c{{end}}),
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
// ModelGen defines a template for model
|
||||
var ModelGen = fmt.Sprintf(`%s
|
||||
|
||||
package {{.pkg}}
|
||||
{{.imports}}
|
||||
{{.vars}}
|
||||
{{.types}}
|
||||
@@ -11,4 +51,5 @@ var Model = `package {{.pkg}}
|
||||
{{.update}}
|
||||
{{.delete}}
|
||||
{{.extraMethod}}
|
||||
`
|
||||
{{.tableName}}
|
||||
`, util.DoNotEditHead)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package template
|
||||
|
||||
// New defines an template for creating model instance
|
||||
var New = `
|
||||
func New{{.upperStartCamelObject}}Model(conn sqlx.SqlConn{{if .withCache}}, c cache.CacheConf{{end}}) {{.upperStartCamelObject}}Model {
|
||||
// New defines the template for creating model instance.
|
||||
const New = `
|
||||
func new{{.upperStartCamelObject}}Model(conn sqlx.SqlConn{{if .withCache}}, c cache.CacheConf{{end}}) *default{{.upperStartCamelObject}}Model {
|
||||
return &default{{.upperStartCamelObject}}Model{
|
||||
{{if .withCache}}CachedConn: sqlc.NewConn(conn, c){{else}}conn:conn{{end}},
|
||||
table: {{.table}},
|
||||
|
||||
8
tools/goctl/model/sql/template/tablename.go
Normal file
8
tools/goctl/model/sql/template/tablename.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package template
|
||||
|
||||
// TableName defines a template that generate the tableName method.
|
||||
const TableName = `
|
||||
func (m *default{{.upperStartCamelObject}}Model) tableName() string {
|
||||
return m.table
|
||||
}
|
||||
`
|
||||
@@ -1,4 +1,4 @@
|
||||
package template
|
||||
|
||||
// Tag defines a tag template text
|
||||
var Tag = "`db:\"{{.field}}\"`"
|
||||
const Tag = "`db:\"{{.field}}\"`"
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package template
|
||||
|
||||
// Types defines a template for types in model
|
||||
var Types = `
|
||||
// Types defines a template for types in model.
|
||||
const Types = `
|
||||
type (
|
||||
{{.upperStartCamelObject}}Model interface{
|
||||
{{.lowerStartCamelObject}}Model interface{
|
||||
{{.method}}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package template
|
||||
|
||||
// Update defines a template for generating update codes
|
||||
var Update = `
|
||||
const (
|
||||
// Update defines a template for generating update codes
|
||||
Update = `
|
||||
func (m *default{{.upperStartCamelObject}}Model) Update(ctx context.Context, data *{{.upperStartCamelObject}}) error {
|
||||
{{if .withCache}}{{.keys}}
|
||||
_, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||
@@ -13,5 +14,6 @@ func (m *default{{.upperStartCamelObject}}Model) Update(ctx context.Context, dat
|
||||
}
|
||||
`
|
||||
|
||||
// UpdateMethod defines an interface method template for generating update codes
|
||||
var UpdateMethod = `Update(ctx context.Context, data *{{.upperStartCamelObject}}) error`
|
||||
// UpdateMethod defines an interface method template for generating update codes
|
||||
UpdateMethod = `Update(ctx context.Context, data *{{.upperStartCamelObject}}) error`
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ func New() *SortedMap {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SortedMap) SetExpression(expression string) (key interface{}, value interface{}, err error) {
|
||||
func (m *SortedMap) SetExpression(expression string) (key, value interface{}, err error) {
|
||||
idx := strings.Index(expression, "=")
|
||||
if idx == -1 {
|
||||
return "", "", ErrInvalidKVExpression
|
||||
@@ -86,7 +86,7 @@ func (m *SortedMap) Get(key interface{}) (interface{}, bool) {
|
||||
return e.Value.(KV)[1], true
|
||||
}
|
||||
|
||||
func (m *SortedMap) GetOr(key interface{}, dft interface{}) interface{} {
|
||||
func (m *SortedMap) GetOr(key, dft interface{}) interface{} {
|
||||
e, ok := m.keys[key]
|
||||
if !ok {
|
||||
return dft
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func Download(url string, filename string) error {
|
||||
func Download(url, filename string) error {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
2
tools/goctl/pkg/env/env.go
vendored
2
tools/goctl/pkg/env/env.go
vendored
@@ -88,7 +88,7 @@ func Get(key string) string {
|
||||
return GetOr(key, "")
|
||||
}
|
||||
|
||||
func GetOr(key string, def string) string {
|
||||
func GetOr(key, def string) string {
|
||||
return goctlEnv.GetStringOr(key, def)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,84 +4,16 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/urfave/cli"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/rpc/generator"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/console"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/env"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
)
|
||||
|
||||
// Deprecated: use ZRPC instead.
|
||||
// RPC is to generate rpc service code from a proto file by specifying a proto file using flag src,
|
||||
// you can specify a target folder for code generation, when the proto file has import, you can specify
|
||||
// the import search directory through the proto_path command, for specific usage, please refer to protoc -h
|
||||
func RPC(c *cli.Context) error {
|
||||
console.Warning("deprecated: use %q instead, for the details see %q",
|
||||
"goctl rpc protoc", "goctl rpc protoc --help")
|
||||
|
||||
if err := prepare(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
src := c.String("src")
|
||||
out := c.String("dir")
|
||||
style := c.String("style")
|
||||
protoImportPath := c.StringSlice("proto_path")
|
||||
goOptions := c.StringSlice("go_opt")
|
||||
home := c.String("home")
|
||||
remote := c.String("remote")
|
||||
branch := c.String("branch")
|
||||
if len(remote) > 0 {
|
||||
repo, _ := util.CloneIntoGitHome(remote, branch)
|
||||
if len(repo) > 0 {
|
||||
home = repo
|
||||
}
|
||||
}
|
||||
if len(home) > 0 {
|
||||
pathx.RegisterGoctlHome(home)
|
||||
}
|
||||
|
||||
if len(src) == 0 {
|
||||
return errors.New("missing -src")
|
||||
}
|
||||
|
||||
if len(out) == 0 {
|
||||
return errors.New("missing -dir")
|
||||
}
|
||||
|
||||
g, err := generator.NewDefaultRPCGenerator(style)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return g.Generate(src, out, protoImportPath, goOptions...)
|
||||
}
|
||||
|
||||
func prepare() error {
|
||||
if !env.CanExec() {
|
||||
return fmt.Errorf("%s: can not start new processes using os.StartProcess or exec.Command", runtime.GOOS)
|
||||
}
|
||||
if _, err := env.LookUpGo(); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := env.LookUpProtoc(); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := env.LookUpProtocGenGo(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RPCNew is to generate rpc greet service, this greet service can speed
|
||||
// up your understanding of the zrpc service structure
|
||||
func RPCNew(c *cli.Context) error {
|
||||
console.Warning("deprecated: it will be removed in the feature, zrpc code generation please use %q instead",
|
||||
"goctl rpc protoc")
|
||||
|
||||
rpcname := c.Args().First()
|
||||
ext := filepath.Ext(rpcname)
|
||||
if len(ext) > 0 {
|
||||
@@ -91,6 +23,7 @@ func RPCNew(c *cli.Context) error {
|
||||
home := c.String("home")
|
||||
remote := c.String("remote")
|
||||
branch := c.String("branch")
|
||||
verbose := c.Bool("verbose")
|
||||
if len(remote) > 0 {
|
||||
repo, _ := util.CloneIntoGitHome(remote, branch)
|
||||
if len(repo) > 0 {
|
||||
@@ -113,12 +46,15 @@ func RPCNew(c *cli.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
g, err := generator.NewDefaultRPCGenerator(style)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return g.Generate(src, filepath.Dir(src), nil)
|
||||
var ctx generator.ZRpcContext
|
||||
ctx.Src = src
|
||||
ctx.GoOutput = filepath.Dir(src)
|
||||
ctx.GrpcOutput = filepath.Dir(src)
|
||||
ctx.IsGooglePlugin = true
|
||||
ctx.Output = filepath.Dir(src)
|
||||
ctx.ProtocCmd = fmt.Sprintf("protoc -I=%s %s --go_out=%s --go-grpc_out=%s", filepath.Dir(src), filepath.Base(src), filepath.Dir(src), filepath.Dir(src))
|
||||
g := generator.NewGenerator(style, verbose)
|
||||
return g.Generate(&ctx)
|
||||
}
|
||||
|
||||
// RPCTemplate is the entry for generate rpc template
|
||||
|
||||
@@ -42,6 +42,7 @@ func ZRPC(c *cli.Context) error {
|
||||
home := c.String("home")
|
||||
remote := c.String("remote")
|
||||
branch := c.String("branch")
|
||||
verbose := c.Bool("verbose")
|
||||
if len(grpcOutList) == 0 {
|
||||
return errInvalidGrpcOutput
|
||||
}
|
||||
@@ -107,12 +108,8 @@ func ZRPC(c *cli.Context) error {
|
||||
ctx.IsGooglePlugin = isGooglePlugin
|
||||
ctx.Output = zrpcOut
|
||||
ctx.ProtocCmd = strings.Join(protocArgs, " ")
|
||||
g, err := generator.NewDefaultRPCGenerator(style, generator.WithZRpcContext(&ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return g.Generate(source, zrpcOut, nil)
|
||||
g := generator.NewGenerator(style, verbose)
|
||||
return g.Generate(&ctx)
|
||||
}
|
||||
|
||||
func removeGoctlFlag(args []string) []string {
|
||||
@@ -121,11 +118,18 @@ func removeGoctlFlag(args []string) []string {
|
||||
for step < len(args) {
|
||||
arg := args[step]
|
||||
switch {
|
||||
case arg == "--style", arg == "--home", arg == "--zrpc_out":
|
||||
case arg == "--style", arg == "--home",
|
||||
arg == "--zrpc_out", arg == "--verbose",
|
||||
arg == "-v", arg == "--remote",
|
||||
arg == "--branch":
|
||||
step += 2
|
||||
continue
|
||||
case strings.HasPrefix(arg, "--style="),
|
||||
strings.HasPrefix(arg, "--home="),
|
||||
strings.HasPrefix(arg, "--verbose="),
|
||||
strings.HasPrefix(arg, "-v="),
|
||||
strings.HasPrefix(arg, "--remote="),
|
||||
strings.HasPrefix(arg, "--branch="),
|
||||
strings.HasPrefix(arg, "--zrpc_out="):
|
||||
step += 1
|
||||
continue
|
||||
|
||||
@@ -83,6 +83,14 @@ func Test_RemoveGoctlFlag(t *testing.T) {
|
||||
source: strings.Fields(`protoc --go_opt=. --go-grpc_out=. --zrpc_out=. foo.proto`),
|
||||
expected: "protoc --go_opt=. --go-grpc_out=. foo.proto",
|
||||
},
|
||||
{
|
||||
source: strings.Fields(`protoc --go_opt=. --go-grpc_out=. --zrpc_out=. --remote=foo --branch=bar foo.proto`),
|
||||
expected: "protoc --go_opt=. --go-grpc_out=. foo.proto",
|
||||
},
|
||||
{
|
||||
source: strings.Fields(`protoc --go_opt=. --go-grpc_out=. --zrpc_out=. --remote foo --branch bar foo.proto`),
|
||||
expected: "protoc --go_opt=. --go-grpc_out=. foo.proto",
|
||||
},
|
||||
}
|
||||
for _, e := range testData {
|
||||
cmd := strings.Join(removeGoctlFlag(e.source), " ")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package common;
|
||||
option go_package="./common";
|
||||
|
||||
message User {
|
||||
string name = 1;
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/env"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/console"
|
||||
)
|
||||
|
||||
// DefaultGenerator defines the environment needs of rpc service generation
|
||||
type DefaultGenerator struct {
|
||||
log console.Console
|
||||
}
|
||||
|
||||
// just test interface implement
|
||||
var _ Generator = (*DefaultGenerator)(nil)
|
||||
|
||||
// NewDefaultGenerator returns an instance of DefaultGenerator
|
||||
func NewDefaultGenerator() Generator {
|
||||
log := console.NewColorConsole()
|
||||
return &DefaultGenerator{
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare provides environment detection generated by rpc service,
|
||||
// including go environment, protoc, whether protoc-gen-go is installed or not
|
||||
func (g *DefaultGenerator) Prepare() error {
|
||||
return env.Prepare(true, true)
|
||||
}
|
||||
@@ -3,22 +3,12 @@ package generator
|
||||
import (
|
||||
"path/filepath"
|
||||
|
||||
conf "github.com/zeromicro/go-zero/tools/goctl/config"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/rpc/parser"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/console"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/ctx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
)
|
||||
|
||||
// RPCGenerator defines a generator and configure
|
||||
type RPCGenerator struct {
|
||||
g Generator
|
||||
cfg *conf.Config
|
||||
ctx *ZRpcContext
|
||||
}
|
||||
|
||||
type RPCGeneratorOption func(g *RPCGenerator)
|
||||
|
||||
type ZRpcContext struct {
|
||||
Src string
|
||||
ProtocCmd string
|
||||
@@ -30,38 +20,11 @@ type ZRpcContext struct {
|
||||
Output string
|
||||
}
|
||||
|
||||
// NewDefaultRPCGenerator wraps Generator with configure
|
||||
func NewDefaultRPCGenerator(style string, options ...RPCGeneratorOption) (*RPCGenerator, error) {
|
||||
cfg, err := conf.NewConfig(style)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewRPCGenerator(NewDefaultGenerator(), cfg, options...), nil
|
||||
}
|
||||
|
||||
// NewRPCGenerator creates an instance for RPCGenerator
|
||||
func NewRPCGenerator(g Generator, cfg *conf.Config, options ...RPCGeneratorOption) *RPCGenerator {
|
||||
out := &RPCGenerator{
|
||||
g: g,
|
||||
cfg: cfg,
|
||||
}
|
||||
for _, opt := range options {
|
||||
opt(out)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func WithZRpcContext(c *ZRpcContext) RPCGeneratorOption {
|
||||
return func(g *RPCGenerator) {
|
||||
g.ctx = c
|
||||
}
|
||||
}
|
||||
|
||||
// Generate generates an rpc service, through the proto file,
|
||||
// code storage directory, and proto import parameters to control
|
||||
// the source file and target location of the rpc service that needs to be generated
|
||||
func (g *RPCGenerator) Generate(src, target string, protoImportPath []string, goOptions ...string) error {
|
||||
abs, err := filepath.Abs(target)
|
||||
func (g *Generator) Generate(zctx *ZRpcContext) error {
|
||||
abs, err := filepath.Abs(zctx.Output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -71,7 +34,7 @@ func (g *RPCGenerator) Generate(src, target string, protoImportPath []string, go
|
||||
return err
|
||||
}
|
||||
|
||||
err = g.g.Prepare()
|
||||
err = g.Prepare()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -82,52 +45,52 @@ func (g *RPCGenerator) Generate(src, target string, protoImportPath []string, go
|
||||
}
|
||||
|
||||
p := parser.NewDefaultProtoParser()
|
||||
proto, err := p.Parse(src)
|
||||
proto, err := p.Parse(zctx.Src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dirCtx, err := mkdir(projectCtx, proto, g.cfg, g.ctx)
|
||||
dirCtx, err := mkdir(projectCtx, proto, g.cfg, zctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = g.g.GenEtc(dirCtx, proto, g.cfg)
|
||||
err = g.GenEtc(dirCtx, proto, g.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = g.g.GenPb(dirCtx, protoImportPath, proto, g.cfg, g.ctx, goOptions...)
|
||||
err = g.GenPb(dirCtx, zctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = g.g.GenConfig(dirCtx, proto, g.cfg)
|
||||
err = g.GenConfig(dirCtx, proto, g.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = g.g.GenSvc(dirCtx, proto, g.cfg)
|
||||
err = g.GenSvc(dirCtx, proto, g.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = g.g.GenLogic(dirCtx, proto, g.cfg)
|
||||
err = g.GenLogic(dirCtx, proto, g.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = g.g.GenServer(dirCtx, proto, g.cfg)
|
||||
err = g.GenServer(dirCtx, proto, g.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = g.g.GenMain(dirCtx, proto, g.cfg)
|
||||
err = g.GenMain(dirCtx, proto, g.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = g.g.GenCall(dirCtx, proto, g.cfg)
|
||||
err = g.GenCall(dirCtx, proto, g.cfg)
|
||||
|
||||
console.NewColorConsole().MarkDone()
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/build"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -10,26 +11,19 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
conf "github.com/zeromicro/go-zero/tools/goctl/config"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/rpc/execx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
)
|
||||
|
||||
var cfg = &conf.Config{
|
||||
NamingFormat: "gozero",
|
||||
}
|
||||
|
||||
func TestRpcGenerate(t *testing.T) {
|
||||
_ = Clean()
|
||||
dispatcher := NewDefaultGenerator()
|
||||
err := dispatcher.Prepare()
|
||||
g := NewGenerator("gozero", true)
|
||||
err := g.Prepare()
|
||||
if err != nil {
|
||||
logx.Error(err)
|
||||
return
|
||||
}
|
||||
projectName := stringx.Rand()
|
||||
g := NewRPCGenerator(dispatcher, cfg)
|
||||
|
||||
src := filepath.Join(build.Default.GOPATH, "src")
|
||||
_, err = os.Stat(src)
|
||||
if err != nil {
|
||||
@@ -46,7 +40,15 @@ func TestRpcGenerate(t *testing.T) {
|
||||
|
||||
// case go path
|
||||
t.Run("GOPATH", func(t *testing.T) {
|
||||
err = g.Generate("./test.proto", projectDir, []string{common}, "Mbase/common.proto=./base")
|
||||
ctx := &ZRpcContext{
|
||||
Src: "./test.proto",
|
||||
ProtocCmd: fmt.Sprintf("protoc -I=%s test.proto --go_out=%s --go_opt=Mbase/common.proto=./base --go-grpc_out=%s", common, projectDir, projectDir),
|
||||
IsGooglePlugin: true,
|
||||
GoOutput: projectDir,
|
||||
GrpcOutput: projectDir,
|
||||
Output: projectDir,
|
||||
}
|
||||
err = g.Generate(ctx)
|
||||
assert.Nil(t, err)
|
||||
_, err = execx.Run("go test "+projectName, projectDir)
|
||||
if err != nil {
|
||||
@@ -67,7 +69,15 @@ func TestRpcGenerate(t *testing.T) {
|
||||
}
|
||||
|
||||
projectDir = filepath.Join(workDir, projectName)
|
||||
err = g.Generate("./test.proto", projectDir, []string{common}, "Mbase/common.proto=./base")
|
||||
ctx := &ZRpcContext{
|
||||
Src: "./test.proto",
|
||||
ProtocCmd: fmt.Sprintf("protoc -I=%s test.proto --go_out=%s --go_opt=Mbase/common.proto=./base --go-grpc_out=%s", common, projectDir, projectDir),
|
||||
IsGooglePlugin: true,
|
||||
GoOutput: projectDir,
|
||||
GrpcOutput: projectDir,
|
||||
Output: projectDir,
|
||||
}
|
||||
err = g.Generate(ctx)
|
||||
assert.Nil(t, err)
|
||||
_, err = execx.Run("go test "+projectName, projectDir)
|
||||
if err != nil {
|
||||
@@ -79,7 +89,15 @@ func TestRpcGenerate(t *testing.T) {
|
||||
|
||||
// case not in go mod and go path
|
||||
t.Run("OTHER", func(t *testing.T) {
|
||||
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
||||
ctx := &ZRpcContext{
|
||||
Src: "./test.proto",
|
||||
ProtocCmd: fmt.Sprintf("protoc -I=%s test.proto --go_out=%s --go_opt=Mbase/common.proto=./base --go-grpc_out=%s", common, projectDir, projectDir),
|
||||
IsGooglePlugin: true,
|
||||
GoOutput: projectDir,
|
||||
GrpcOutput: projectDir,
|
||||
Output: projectDir,
|
||||
}
|
||||
err = g.Generate(ctx)
|
||||
assert.Nil(t, err)
|
||||
_, err = execx.Run("go test "+projectName, projectDir)
|
||||
if err != nil {
|
||||
|
||||
@@ -58,7 +58,7 @@ func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
|
||||
callFunctionTemplate = `
|
||||
{{if .hasComment}}{{.comment}}{{end}}
|
||||
func (m *default{{.serviceName}}) {{.method}}(ctx context.Context{{if .hasReq}}, in *{{.pbRequest}}{{end}}, opts ...grpc.CallOption) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error) {
|
||||
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
|
||||
client := {{if .isCallPkgSameToGrpcPkg}}{{else}}{{.package}}.{{end}}New{{.rpcServiceName}}Client(m.cli.Conn())
|
||||
return client.{{.method}}(ctx{{if .hasReq}}, in{{end}}, opts...)
|
||||
}
|
||||
`
|
||||
@@ -66,10 +66,12 @@ func (m *default{{.serviceName}}) {{.method}}(ctx context.Context{{if .hasReq}},
|
||||
|
||||
// GenCall generates the rpc client code, which is the entry point for the rpc service call.
|
||||
// It is a layer of encapsulation for the rpc client and shields the details in the pb.
|
||||
func (g *DefaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
|
||||
func (g *Generator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
|
||||
dir := ctx.GetCall()
|
||||
service := proto.Service
|
||||
head := util.GetHead(proto.Name)
|
||||
isCallPkgSameToPbPkg := ctx.GetCall().Filename == ctx.GetPb().Filename
|
||||
isCallPkgSameToGrpcPkg := ctx.GetCall().Filename == ctx.GetProtoGo().Filename
|
||||
|
||||
callFilename, err := format.FileNamingFormat(cfg.NamingFormat, service.Name)
|
||||
if err != nil {
|
||||
@@ -77,12 +79,12 @@ func (g *DefaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf
|
||||
}
|
||||
|
||||
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", callFilename))
|
||||
functions, err := g.genFunction(proto.PbPackage, service)
|
||||
functions, err := g.genFunction(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service)
|
||||
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -93,11 +95,19 @@ func (g *DefaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf
|
||||
}
|
||||
|
||||
alias := collection.NewSet()
|
||||
for _, item := range proto.Message {
|
||||
msgName := getMessageName(*item.Message)
|
||||
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
|
||||
if !isCallPkgSameToPbPkg {
|
||||
for _, item := range proto.Message {
|
||||
msgName := getMessageName(*item.Message)
|
||||
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
|
||||
}
|
||||
}
|
||||
|
||||
pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
|
||||
protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
|
||||
if isCallPkgSameToGrpcPkg {
|
||||
pbPackage = ""
|
||||
protoGoPackage = ""
|
||||
}
|
||||
aliasKeys := alias.KeysStr()
|
||||
sort.Strings(aliasKeys)
|
||||
err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
|
||||
@@ -105,8 +115,8 @@ func (g *DefaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf
|
||||
"alias": strings.Join(aliasKeys, pathx.NL),
|
||||
"head": head,
|
||||
"filePackage": dir.Base,
|
||||
"pbPackage": fmt.Sprintf(`"%s"`, ctx.GetPb().Package),
|
||||
"protoGoPackage": fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package),
|
||||
"pbPackage": pbPackage,
|
||||
"protoGoPackage": protoGoPackage,
|
||||
"serviceName": stringx.From(service.Name).ToCamel(),
|
||||
"functions": strings.Join(functions, pathx.NL),
|
||||
"interface": strings.Join(iFunctions, pathx.NL),
|
||||
@@ -136,7 +146,7 @@ func getMessageName(msg proto.Message) string {
|
||||
return strings.Join(list, "_")
|
||||
}
|
||||
|
||||
func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service) ([]string, error) {
|
||||
func (g *Generator) genFunction(goPackage string, service parser.Service, isCallPkgSameToGrpcPkg bool) ([]string, error) {
|
||||
functions := make([]string, 0)
|
||||
|
||||
for _, rpc := range service.RPC {
|
||||
@@ -147,18 +157,22 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
|
||||
|
||||
comment := parser.GetComment(rpc.Doc())
|
||||
streamServer := fmt.Sprintf("%s.%s_%s%s", goPackage, parser.CamelCase(service.Name), parser.CamelCase(rpc.Name), "Client")
|
||||
if isCallPkgSameToGrpcPkg {
|
||||
streamServer = fmt.Sprintf("%s_%s%s", parser.CamelCase(service.Name), parser.CamelCase(rpc.Name), "Client")
|
||||
}
|
||||
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
|
||||
"serviceName": stringx.From(service.Name).ToCamel(),
|
||||
"rpcServiceName": parser.CamelCase(service.Name),
|
||||
"method": parser.CamelCase(rpc.Name),
|
||||
"package": goPackage,
|
||||
"pbRequest": parser.CamelCase(rpc.RequestType),
|
||||
"pbResponse": parser.CamelCase(rpc.ReturnsType),
|
||||
"hasComment": len(comment) > 0,
|
||||
"comment": comment,
|
||||
"hasReq": !rpc.StreamsRequest,
|
||||
"notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
|
||||
"streamBody": streamServer,
|
||||
"serviceName": stringx.From(service.Name).ToCamel(),
|
||||
"rpcServiceName": parser.CamelCase(service.Name),
|
||||
"method": parser.CamelCase(rpc.Name),
|
||||
"package": goPackage,
|
||||
"pbRequest": parser.CamelCase(rpc.RequestType),
|
||||
"pbResponse": parser.CamelCase(rpc.ReturnsType),
|
||||
"hasComment": len(comment) > 0,
|
||||
"comment": comment,
|
||||
"hasReq": !rpc.StreamsRequest,
|
||||
"notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
|
||||
"streamBody": streamServer,
|
||||
"isCallPkgSameToGrpcPkg": isCallPkgSameToGrpcPkg,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -170,7 +184,7 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
|
||||
return functions, nil
|
||||
}
|
||||
|
||||
func (g *DefaultGenerator) getInterfaceFuncs(goPackage string, service parser.Service) ([]string, error) {
|
||||
func (g *Generator) getInterfaceFuncs(goPackage string, service parser.Service, isCallPkgSameToGrpcPkg bool) ([]string, error) {
|
||||
functions := make([]string, 0)
|
||||
|
||||
for _, rpc := range service.RPC {
|
||||
@@ -181,6 +195,9 @@ func (g *DefaultGenerator) getInterfaceFuncs(goPackage string, service parser.Se
|
||||
|
||||
comment := parser.GetComment(rpc.Doc())
|
||||
streamServer := fmt.Sprintf("%s.%s_%s%s", goPackage, parser.CamelCase(service.Name), parser.CamelCase(rpc.Name), "Client")
|
||||
if isCallPkgSameToGrpcPkg {
|
||||
streamServer = fmt.Sprintf("%s_%s%s", parser.CamelCase(service.Name), parser.CamelCase(rpc.Name), "Client")
|
||||
}
|
||||
buffer, err := util.With("interfaceFn").Parse(text).Execute(
|
||||
map[string]interface{}{
|
||||
"hasComment": len(comment) > 0,
|
||||
|
||||
@@ -24,7 +24,7 @@ type Config struct {
|
||||
// which contains the zrpc.RpcServerConf configuration item by default.
|
||||
// You can specify the naming style of the target file name through config.Config. For details,
|
||||
// see https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/config.go
|
||||
func (g *DefaultGenerator) GenConfig(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
|
||||
func (g *Generator) GenConfig(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
|
||||
dir := ctx.GetConfig()
|
||||
configFilename, err := format.FileNamingFormat(cfg.NamingFormat, "config")
|
||||
if err != nil {
|
||||
|
||||
@@ -1,19 +1,36 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
conf "github.com/zeromicro/go-zero/tools/goctl/config"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/rpc/parser"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/env"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/console"
|
||||
)
|
||||
|
||||
// Generator defines a generator interface to describe how to generate rpc service
|
||||
type Generator interface {
|
||||
Prepare() error
|
||||
GenMain(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
|
||||
GenCall(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
|
||||
GenEtc(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
|
||||
GenConfig(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
|
||||
GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
|
||||
GenServer(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
|
||||
GenSvc(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
|
||||
GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, cfg *conf.Config, c *ZRpcContext, goOptions ...string) error
|
||||
// Generator defines the environment needs of rpc service generation
|
||||
type Generator struct {
|
||||
log console.Console
|
||||
cfg *conf.Config
|
||||
verbose bool
|
||||
}
|
||||
|
||||
// NewGenerator returns an instance of Generator
|
||||
func NewGenerator(style string, verbose bool) *Generator {
|
||||
cfg, err := conf.NewConfig(style)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
log := console.NewColorConsole(verbose)
|
||||
return &Generator{
|
||||
log: log,
|
||||
cfg: cfg,
|
||||
verbose: verbose,
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare provides environment detection generated by rpc service,
|
||||
// including go environment, protoc, whether protoc-gen-go is installed or not
|
||||
func (g *Generator) Prepare() error {
|
||||
return env.Prepare(true, true, g.verbose)
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ Etcd:
|
||||
|
||||
// GenEtc generates the yaml configuration file of the rpc service,
|
||||
// including host, port monitoring configuration items and etcd configuration
|
||||
func (g *DefaultGenerator) GenEtc(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
|
||||
func (g *Generator) GenEtc(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
|
||||
dir := ctx.GetEtc()
|
||||
etcFilename, err := format.FileNamingFormat(cfg.NamingFormat, ctx.GetServiceName().Source())
|
||||
if err != nil {
|
||||
|
||||
@@ -50,7 +50,7 @@ func (l *{{.logicName}}) {{.method}} ({{if .hasReq}}in {{.request}}{{if .stream}
|
||||
)
|
||||
|
||||
// GenLogic generates the logic file of the rpc service, which corresponds to the RPC definition items in proto.
|
||||
func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
|
||||
func (g *Generator) GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
|
||||
dir := ctx.GetLogic()
|
||||
service := proto.Service.Service.Name
|
||||
for _, rpc := range proto.Service.RPC {
|
||||
@@ -84,7 +84,7 @@ func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *con
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *DefaultGenerator) genLogicFunction(serviceName, goPackage string, rpc *parser.RPC) (string, error) {
|
||||
func (g *Generator) genLogicFunction(serviceName, goPackage string, rpc *parser.RPC) (string, error) {
|
||||
functions := make([]string, 0)
|
||||
text, err := pathx.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
|
||||
if err != nil {
|
||||
|
||||
@@ -53,7 +53,7 @@ func main() {
|
||||
`
|
||||
|
||||
// GenMain generates the main file of the rpc service, which is an rpc service program call entry
|
||||
func (g *DefaultGenerator) GenMain(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
|
||||
func (g *Generator) GenMain(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
|
||||
mainFilename, err := format.FileNamingFormat(cfg.NamingFormat, ctx.GetServiceName().Source())
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,118 +1,22 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
conf "github.com/zeromicro/go-zero/tools/goctl/config"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/rpc/execx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/rpc/parser"
|
||||
)
|
||||
|
||||
const googleProtocGenGoErr = `--go_out: protoc-gen-go: plugins are not supported; use 'protoc --go-grpc_out=...' to generate gRPC`
|
||||
|
||||
// GenPb generates the pb.go file, which is a layer of packaging for protoc to generate gprc,
|
||||
// but the commands and flags in protoc are not completely joined in goctl. At present, proto_path(-I) is introduced
|
||||
func (g *DefaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, _ *conf.Config, c *ZRpcContext, goOptions ...string) error {
|
||||
if c != nil {
|
||||
return g.genPbDirect(ctx, c)
|
||||
}
|
||||
|
||||
// deprecated: use genPbDirect instead.
|
||||
dir := ctx.GetPb()
|
||||
cw := new(bytes.Buffer)
|
||||
directory, base := filepath.Split(proto.Src)
|
||||
directory = filepath.Clean(directory)
|
||||
cw.WriteString("protoc ")
|
||||
protoImportPathSet := collection.NewSet()
|
||||
isSamePackage := true
|
||||
for _, ip := range protoImportPath {
|
||||
pip := " --proto_path=" + ip
|
||||
if protoImportPathSet.Contains(pip) {
|
||||
continue
|
||||
}
|
||||
|
||||
abs, err := filepath.Abs(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if abs == directory {
|
||||
isSamePackage = true
|
||||
} else {
|
||||
isSamePackage = false
|
||||
}
|
||||
|
||||
protoImportPathSet.AddStr(pip)
|
||||
cw.WriteString(pip)
|
||||
}
|
||||
currentPath := " --proto_path=" + directory
|
||||
if !protoImportPathSet.Contains(currentPath) {
|
||||
cw.WriteString(currentPath)
|
||||
}
|
||||
|
||||
cw.WriteString(" " + proto.Name)
|
||||
if strings.Contains(proto.GoPackage, "/") {
|
||||
cw.WriteString(" --go_out=plugins=grpc:" + ctx.GetMain().Filename)
|
||||
} else {
|
||||
cw.WriteString(" --go_out=plugins=grpc:" + dir.Filename)
|
||||
}
|
||||
|
||||
// Compatible with version 1.4.0,github.com/golang/protobuf/protoc-gen-go@v1.4.0
|
||||
// --go_opt usage please see https://developers.google.com/protocol-buffers/docs/reference/go-generated#package
|
||||
optSet := collection.NewSet()
|
||||
for _, op := range goOptions {
|
||||
opt := " --go_opt=" + op
|
||||
if optSet.Contains(opt) {
|
||||
continue
|
||||
}
|
||||
|
||||
optSet.AddStr(op)
|
||||
cw.WriteString(" --go_opt=" + op)
|
||||
}
|
||||
|
||||
var currentFileOpt string
|
||||
if !isSamePackage || (len(proto.GoPackage) > 0 && proto.GoPackage != proto.Package.Name) {
|
||||
if filepath.IsAbs(proto.GoPackage) {
|
||||
currentFileOpt = " --go_opt=M" + base + "=" + proto.GoPackage
|
||||
} else if strings.Contains(proto.GoPackage, string(filepath.Separator)) {
|
||||
currentFileOpt = " --go_opt=M" + base + "=./" + proto.GoPackage
|
||||
} else {
|
||||
currentFileOpt = " --go_opt=M" + base + "=../" + proto.GoPackage
|
||||
}
|
||||
} else {
|
||||
currentFileOpt = " --go_opt=M" + base + "=."
|
||||
}
|
||||
|
||||
if !optSet.Contains(currentFileOpt) {
|
||||
cw.WriteString(currentFileOpt)
|
||||
}
|
||||
|
||||
command := cw.String()
|
||||
g.log.Debug(command)
|
||||
_, err := execx.Run(command, "")
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), googleProtocGenGoErr) {
|
||||
return errors.New(`unsupported plugin protoc-gen-go which installed from the following source:
|
||||
google.golang.org/protobuf/cmd/protoc-gen-go,
|
||||
github.com/protocolbuffers/protobuf-go/cmd/protoc-gen-go;
|
||||
|
||||
Please replace it by the following command, we recommend to use version before v1.3.5:
|
||||
go get -u github.com/golang/protobuf/protoc-gen-go`)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
func (g *Generator) GenPb(ctx DirContext, c *ZRpcContext) error {
|
||||
return g.genPbDirect(ctx, c)
|
||||
}
|
||||
|
||||
func (g *DefaultGenerator) genPbDirect(ctx DirContext, c *ZRpcContext) error {
|
||||
func (g *Generator) genPbDirect(ctx DirContext, c *ZRpcContext) error {
|
||||
g.log.Debug("[command]: %s", c.ProtocCmd)
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
@@ -126,7 +30,7 @@ func (g *DefaultGenerator) genPbDirect(ctx DirContext, c *ZRpcContext) error {
|
||||
return g.setPbDir(ctx, c)
|
||||
}
|
||||
|
||||
func (g *DefaultGenerator) setPbDir(ctx DirContext, c *ZRpcContext) error {
|
||||
func (g *Generator) setPbDir(ctx DirContext, c *ZRpcContext) error {
|
||||
pbDir, err := findPbFile(c.GoOutput, false)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -25,7 +25,7 @@ message Resp{}
|
||||
service Greeter {
|
||||
rpc greet(Req) returns (Resp);
|
||||
}
|
||||
`), 0666)
|
||||
`), 0o666)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
return
|
||||
|
||||
@@ -48,7 +48,7 @@ func (s *{{.server}}Server) {{.method}} ({{if .notStream}}ctx context.Context,{{
|
||||
)
|
||||
|
||||
// GenServer generates rpc server file, which is an implementation of rpc server
|
||||
func (g *DefaultGenerator) GenServer(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
|
||||
func (g *Generator) GenServer(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
|
||||
dir := ctx.GetServer()
|
||||
logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package)
|
||||
svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
|
||||
@@ -94,7 +94,7 @@ func (g *DefaultGenerator) GenServer(ctx DirContext, proto parser.Proto, cfg *co
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *DefaultGenerator) genFunctions(goPackage string, service parser.Service) ([]string, error) {
|
||||
func (g *Generator) genFunctions(goPackage string, service parser.Service) ([]string, error) {
|
||||
var functionList []string
|
||||
for _, rpc := range service.RPC {
|
||||
text, err := pathx.LoadTemplate(category, serverFuncTemplateFile, functionTemplate)
|
||||
|
||||
@@ -28,7 +28,7 @@ func NewServiceContext(c config.Config) *ServiceContext {
|
||||
|
||||
// GenSvc generates the servicecontext.go file, which is the resource dependency of a service,
|
||||
// such as rpc dependency, model dependency, etc.
|
||||
func (g *DefaultGenerator) GenSvc(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
|
||||
func (g *Generator) GenSvc(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
|
||||
dir := ctx.GetSvc()
|
||||
svcFilename, err := format.FileNamingFormat(cfg.NamingFormat, "service_context")
|
||||
if err != nil {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
const rpcTemplateText = `syntax = "proto3";
|
||||
|
||||
package {{.package}};
|
||||
option go_package="./{{.package}}";
|
||||
|
||||
message Request {
|
||||
string ping = 1;
|
||||
|
||||
@@ -24,7 +24,9 @@ type (
|
||||
Must(err error)
|
||||
}
|
||||
|
||||
colorConsole struct{}
|
||||
colorConsole struct {
|
||||
enable bool
|
||||
}
|
||||
|
||||
// for idea log
|
||||
ideaConsole struct{}
|
||||
@@ -39,45 +41,75 @@ func NewConsole(idea bool) Console {
|
||||
}
|
||||
|
||||
// NewColorConsole returns an instance of colorConsole
|
||||
func NewColorConsole() Console {
|
||||
return &colorConsole{}
|
||||
func NewColorConsole(enable ...bool) Console {
|
||||
logEnable := true
|
||||
for _, e := range enable {
|
||||
logEnable = e
|
||||
}
|
||||
return &colorConsole{
|
||||
enable: logEnable,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *colorConsole) Info(format string, a ...interface{}) {
|
||||
if !c.enable {
|
||||
return
|
||||
}
|
||||
msg := fmt.Sprintf(format, a...)
|
||||
fmt.Println(msg)
|
||||
}
|
||||
|
||||
func (c *colorConsole) Debug(format string, a ...interface{}) {
|
||||
if !c.enable {
|
||||
return
|
||||
}
|
||||
msg := fmt.Sprintf(format, a...)
|
||||
println(aurora.BrightCyan(msg))
|
||||
}
|
||||
|
||||
func (c *colorConsole) Success(format string, a ...interface{}) {
|
||||
if !c.enable {
|
||||
return
|
||||
}
|
||||
msg := fmt.Sprintf(format, a...)
|
||||
println(aurora.BrightGreen(msg))
|
||||
}
|
||||
|
||||
func (c *colorConsole) Warning(format string, a ...interface{}) {
|
||||
if !c.enable {
|
||||
return
|
||||
}
|
||||
msg := fmt.Sprintf(format, a...)
|
||||
println(aurora.BrightYellow(msg))
|
||||
}
|
||||
|
||||
func (c *colorConsole) Error(format string, a ...interface{}) {
|
||||
if !c.enable {
|
||||
return
|
||||
}
|
||||
msg := fmt.Sprintf(format, a...)
|
||||
println(aurora.BrightRed(msg))
|
||||
}
|
||||
|
||||
func (c *colorConsole) Fatalln(format string, a ...interface{}) {
|
||||
if !c.enable {
|
||||
return
|
||||
}
|
||||
c.Error(format, a...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func (c *colorConsole) MarkDone() {
|
||||
if !c.enable {
|
||||
return
|
||||
}
|
||||
c.Success("Done.")
|
||||
}
|
||||
|
||||
func (c *colorConsole) Must(err error) {
|
||||
if !c.enable {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
c.Fatalln("%+v", err)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
)
|
||||
|
||||
func CloneIntoGitHome(url string, branch string) (dir string, err error) {
|
||||
func CloneIntoGitHome(url, branch string) (dir string, err error) {
|
||||
gitHome, err := pathx.GetGitHome()
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
package util
|
||||
|
||||
var headTemplate = `// Code generated by goctl. DO NOT EDIT!
|
||||
const (
|
||||
// DoNotEditHead added to the beginning of a file to prompt the user not to edit
|
||||
DoNotEditHead = "// Code generated by goctl. DO NOT EDIT!"
|
||||
|
||||
headTemplate = DoNotEditHead + `
|
||||
// Source: {{.source}}`
|
||||
)
|
||||
|
||||
// GetHead returns a code head string with source filename
|
||||
func GetHead(source string) string {
|
||||
|
||||
@@ -79,7 +79,7 @@ func TestGetGoctlHome(t *testing.T) {
|
||||
t.Run("goctl_is_file", func(t *testing.T) {
|
||||
tmpFile := filepath.Join(t.TempDir(), "a.tmp")
|
||||
backupTempFile := tmpFile + ".old"
|
||||
err := ioutil.WriteFile(tmpFile, nil, 0666)
|
||||
err := ioutil.WriteFile(tmpFile, nil, 0o666)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -104,5 +104,4 @@ func TestGetGoctlHome(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, dir, home)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user