Compare commits

...

30 Commits

Author SHA1 Message Date
Kevin Wan
78ea0769fd feat: simplify httpc (#1748)
* feat: simplify httpc

* chore: fix lint errors

* chore: fix log url issue

* chore: fix log url issue

* refactor: handle resp & err in ResponseHandler

* chore: remove unnecessary var names in return clause
2022-04-03 14:32:27 +08:00
Kevin Wan
e0fa8d820d feat: return original value of setbit in redis (#1746) 2022-04-02 20:25:51 +08:00
Kevin Wan
dfd58c213c fix: model generation bug on with cache (#1743)
* fix: model generation bug on with cache

* chore: refine template

* chore: fix test failure
2022-04-02 15:36:06 +08:00
Kevin Wan
83cacf51b7 chore: update goctl version to 1.3.4 (#1742) 2022-04-02 14:19:34 +08:00
Kevin Wan
6dccfa29fd feat: let model customizable (#1738) 2022-04-01 22:19:52 +08:00
anqiansong
7e0b0ab0b1 Fix zrpc code generation error with --remote (#1739)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-04-01 22:19:33 +08:00
Kevin Wan
ac18cc470d chore: refactor to use const instead of var (#1731) 2022-04-01 15:23:45 +08:00
Fyn
f4471846ff feat(goctl): supports model code 'DO NOT EDIT' (#1728)
Resolves: #1710
2022-04-01 14:48:45 +08:00
anqiansong
9c2d526a11 Fix unit test (#1730)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-04-01 14:46:12 +08:00
Kevin Wan
2b9fc26c38 refactor: guard timeout on API files (#1726) 2022-03-31 21:39:02 +08:00
Xiaoju Jiang
321dc2d410 Added support for setting the parameter size accepted by the interface and custom timeout and maxbytes in API file (#1713)
* Added support for setting the parameter size accepted by the interface

* support custom timeout and maxbytes in API file

* support timeout used unit

* remove goctl maxBytes
2022-03-31 20:20:00 +08:00
Fyn
500bd87c85 fix(goctl): api format with reader input (#1722)
resolves #1721
2022-03-31 00:20:51 +08:00
Kevin Wan
e9620c8c05 chore: refactor code (#1708) 2022-03-24 22:10:15 +08:00
aimuz
70e51bb352 fix: empty slice are set to nil (#1702)
support for empty slce, Same behavior as json.Unmarshal
2022-03-24 21:41:38 +08:00
Kevin Wan
278cd123c8 feat: remove reentrance in redislock, timeout bug (#1704) 2022-03-24 16:17:01 +08:00
Kevin Wan
3febb1a5d0 chore: refactor code (#1700) 2022-03-23 19:09:45 +08:00
Mikael
d8054d8def fix -cache=true insert no clean cache (#1672)
* fix -cache=true insert no clean cache

* fix -cache=true insert no clean cache
2022-03-23 18:55:16 +08:00
Kevin Wan
ec271db7a0 chore: refactor code (#1699) 2022-03-23 18:24:44 +08:00
benqi
bbac994c8a feat: add getset command in redis and kv (#1693) 2022-03-23 18:02:56 +08:00
Kevin Wan
c1d9e6a00b feat: add httpc.Parse (#1698) 2022-03-23 17:58:21 +08:00
anqiansong
0aeb49a6b0 Add verbose flag (#1696)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-03-22 21:00:26 +08:00
Kevin Wan
fe262766b4 chore: fix lint issue (#1694) 2022-03-22 13:31:05 +08:00
Kevin Wan
7181505c8a Update LICENSE 2022-03-21 10:32:41 +08:00
Kevin Wan
f060a226bc refactor: simplify the code (#1670) 2022-03-20 17:26:12 +08:00
Mervin.Wong
93d524b797 fix: the new RawFieldNames considers the tag with options. (#1663)
Co-authored-by: JinfaWang <wangjinfa@iie.ac.cn>
2022-03-20 16:59:19 +08:00
anqiansong
5c169f4f49 Remove debug log (#1669)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-03-20 16:28:36 +08:00
Kevin Wan
d29dfa12e3 feat: support -base to specify base image for goctl docker (#1668)
* feat: support -base to specify base image for goctl docker

* chore: update usage
2022-03-20 11:17:55 +08:00
anqiansong
194f55e08e Remove unused code (#1667)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-03-19 23:15:11 +08:00
Kevin Wan
c0f9892fe3 feat: add Dockerfile for goctl (#1666) 2022-03-19 23:07:17 +08:00
anqiansong
227104d7d7 feat: Remove command goctl rpc proto (#1665)
* Fix goctl completion expression

* Fix code generation error if the pkg of pb/grpc is same to zrpc call client pkg

* Remove deprecated comment on action goctl rpc new

* Remove zrpc code generation on action goctl rpc proto

* Remove zrpc code generation on action goctl rpc proto

* Remove Generator interface

Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-03-19 22:50:22 +08:00
86 changed files with 1132 additions and 648 deletions

View File

@@ -1,6 +1,6 @@
MIT License
Copyright (c) 2020 xiaoheiban_server_go
Copyright (c) 2022 zeromicro
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@@ -448,7 +448,15 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
dereffedBaseType := Deref(baseType)
dereffedBaseKind := dereffedBaseType.Kind()
refValue := reflect.ValueOf(mapValue)
if refValue.IsNil() {
return nil
}
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
if refValue.Len() == 0 {
value.Set(conv)
return nil
}
var valid bool
for i := 0; i < refValue.Len(); i++ {

View File

@@ -198,6 +198,49 @@ func TestUnmarshalIntWithDefault(t *testing.T) {
assert.Equal(t, 1, in.Int)
}
func TestUnmarshalBoolSliceRequired(t *testing.T) {
type inner struct {
Bools []bool `key:"bools"`
}
var in inner
assert.NotNil(t, UnmarshalKey(map[string]interface{}{}, &in))
}
func TestUnmarshalBoolSliceNil(t *testing.T) {
type inner struct {
Bools []bool `key:"bools,optional"`
}
var in inner
assert.Nil(t, UnmarshalKey(map[string]interface{}{}, &in))
assert.Nil(t, in.Bools)
}
func TestUnmarshalBoolSliceNilExplicit(t *testing.T) {
type inner struct {
Bools []bool `key:"bools,optional"`
}
var in inner
assert.Nil(t, UnmarshalKey(map[string]interface{}{
"bools": nil,
}, &in))
assert.Nil(t, in.Bools)
}
func TestUnmarshalBoolSliceEmpty(t *testing.T) {
type inner struct {
Bools []bool `key:"bools,optional"`
}
var in inner
assert.Nil(t, UnmarshalKey(map[string]interface{}{
"bools": []bool{},
}, &in))
assert.Empty(t, in.Bools)
}
func TestUnmarshalBoolSliceWithDefault(t *testing.T) {
type inner struct {
Bools []bool `key:"bools,default=[true,false]"`
@@ -330,28 +373,34 @@ func TestUnmarshalFloat(t *testing.T) {
func TestUnmarshalInt64Slice(t *testing.T) {
var v struct {
Ages []int64 `key:"ages"`
Ages []int64 `key:"ages"`
Slice []int64 `key:"slice"`
}
m := map[string]interface{}{
"ages": []int64{1, 2},
"ages": []int64{1, 2},
"slice": []interface{}{},
}
ast := assert.New(t)
ast.Nil(UnmarshalKey(m, &v))
ast.ElementsMatch([]int64{1, 2}, v.Ages)
ast.Equal([]int64{}, v.Slice)
}
func TestUnmarshalIntSlice(t *testing.T) {
var v struct {
Ages []int `key:"ages"`
Ages []int `key:"ages"`
Slice []int `key:"slice"`
}
m := map[string]interface{}{
"ages": []int{1, 2},
"ages": []int{1, 2},
"slice": []interface{}{},
}
ast := assert.New(t)
ast.Nil(UnmarshalKey(m, &v))
ast.ElementsMatch([]int{1, 2}, v.Ages)
ast.Equal([]int{}, v.Slice)
}
func TestUnmarshalString(t *testing.T) {

View File

@@ -11,6 +11,7 @@ const (
mega = 1024 * 1024
)
// DisplayStats prints the goroutine, memory, GC stats with given interval, default to 5 seconds.
func DisplayStats(interval ...time.Duration) {
duration := defaultInterval
for _, val := range interval {

View File

@@ -41,6 +41,16 @@ func RawFieldNames(in interface{}, postgresSql ...bool) []string {
out = append(out, fmt.Sprintf("`%s`", fi.Name))
}
default:
// get tag name with the tag opton, e.g.:
// `db:"id"`
// `db:"id,type=char,length=16"`
// `db:",type=char,length=16"`
if strings.Contains(tagv, ",") {
tagv = strings.TrimSpace(strings.Split(tagv, ",")[0])
}
if len(tagv) == 0 {
tagv = fi.Name
}
if pg {
out = append(out, tagv)
} else {

View File

@@ -22,3 +22,20 @@ func TestFieldNames(t *testing.T) {
assert.Equal(t, expected, out)
})
}
type mockedUserWithOptions struct {
ID string `db:"id" json:"id,omitempty"`
UserName string `db:"user_name,type=varchar,length=255" json:"userName,omitempty"`
Sex int `db:"sex" json:"sex,omitempty"`
UUID string `db:",type=varchar,length=16" uuid:"uuid,omitempty"`
Age int `db:"age" json:"age"`
}
func TestFieldNamesWithTagOptions(t *testing.T) {
t.Run("new", func(t *testing.T) {
var u mockedUserWithOptions
out := RawFieldNames(&u)
expected := []string{"`id`", "`user_name`", "`sex`", "`UUID`", "`age`"}
assert.Equal(t, expected, out)
})
}

View File

@@ -24,6 +24,7 @@ type (
Expire(key string, seconds int) error
Expireat(key string, expireTime int64) error
Get(key string) (string, error)
GetSet(key, value string) (string, error)
Hdel(key, field string) (bool, error)
Hexists(key, field string) (bool, error)
Hget(key, field string) (string, error)
@@ -459,6 +460,15 @@ func (cs clusterStore) SetnxEx(key, value string, seconds int) (bool, error) {
return node.SetnxEx(key, value, seconds)
}
func (cs clusterStore) GetSet(key, value string) (string, error) {
node, err := cs.getRedis(key)
if err != nil {
return "", err
}
return node.GetSet(key, value)
}
func (cs clusterStore) Sismember(key string, value interface{}) (bool, error) {
node, err := cs.getRedis(key)
if err != nil {

View File

@@ -490,6 +490,29 @@ func TestRedis_SetExNx(t *testing.T) {
})
}
func TestRedis_Getset(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.GetSet("hello", "world")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) {
val, err := client.GetSet("hello", "world")
assert.Nil(t, err)
assert.Equal(t, "", val)
val, err = client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "world", val)
val, err = client.GetSet("hello", "newworld")
assert.Nil(t, err)
assert.Equal(t, "world", val)
val, err = client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "newworld", val)
_, err = client.Del("hello")
assert.Nil(t, err)
})
}
func TestRedis_SetGetDelHashField(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
err := store.Hset("key", "field", "value")

View File

@@ -640,6 +640,29 @@ func (s *Redis) GetBitCtx(ctx context.Context, key string, offset int64) (val in
return
}
// GetSet is the implementation of redis getset command.
func (s *Redis) GetSet(key, value string) (string, error) {
return s.GetSetCtx(context.Background(), key, value)
}
// GetSetCtx is the implementation of redis getset command.
func (s *Redis) GetSetCtx(ctx context.Context, key, value string) (val string, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
if val, err = conn.GetSet(ctx, key, value).Result(); err == red.Nil {
return nil
}
return err
}, acceptable)
return
}
// Hdel is the implementation of redis hdel command.
func (s *Redis) Hdel(key string, fields ...string) (bool, error) {
return s.HdelCtx(context.Background(), key, fields...)
@@ -1381,21 +1404,28 @@ func (s *Redis) ScanCtx(ctx context.Context, cursor uint64, match string, count
}
// SetBit is the implementation of redis setbit command.
func (s *Redis) SetBit(key string, offset int64, value int) error {
func (s *Redis) SetBit(key string, offset int64, value int) (int, error) {
return s.SetBitCtx(context.Background(), key, offset, value)
}
// SetBitCtx is the implementation of redis setbit command.
func (s *Redis) SetBitCtx(ctx context.Context, key string, offset int64, value int) error {
return s.brk.DoWithAcceptable(func() error {
func (s *Redis) SetBitCtx(ctx context.Context, key string, offset int64, value int) (val int, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
_, err = conn.SetBit(ctx, key, offset, value).Result()
return err
v, err := conn.SetBit(ctx, key, offset, value).Result()
if err != nil {
return err
}
val = int(v)
return nil
}, acceptable)
return
}
// Sscan is the implementation of redis sscan command.

View File

@@ -387,30 +387,33 @@ func TestRedis_Mget(t *testing.T) {
func TestRedis_SetBit(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := New(client.Addr, badType()).SetBit("key", 1, 1)
_, err := New(client.Addr, badType()).SetBit("key", 1, 1)
assert.NotNil(t, err)
err = client.SetBit("key", 1, 1)
val, err := client.SetBit("key", 1, 1)
assert.Nil(t, err)
assert.Equal(t, 0, val)
})
}
func TestRedis_GetBit(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.SetBit("key", 2, 1)
val, err := client.SetBit("key", 2, 1)
assert.Nil(t, err)
assert.Equal(t, 0, val)
_, err = New(client.Addr, badType()).GetBit("key", 2)
assert.NotNil(t, err)
val, err := client.GetBit("key", 2)
v, err := client.GetBit("key", 2)
assert.Nil(t, err)
assert.Equal(t, 1, val)
assert.Equal(t, 1, v)
})
}
func TestRedis_BitCount(t *testing.T) {
runOnRedis(t, func(client *Redis) {
for i := 0; i < 11; i++ {
err := client.SetBit("key", int64(i), 1)
val, err := client.SetBit("key", int64(i), 1)
assert.Nil(t, err)
assert.Equal(t, 0, val)
}
_, err := New(client.Addr, badType()).BitCount("key", 0, -1)
@@ -701,6 +704,28 @@ func TestRedis_Set(t *testing.T) {
})
}
func TestRedis_GetSet(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := New(client.Addr, badType()).GetSet("hello", "world")
assert.NotNil(t, err)
val, err := client.GetSet("hello", "world")
assert.Nil(t, err)
assert.Equal(t, "", val)
val, err = client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "world", val)
val, err = client.GetSet("hello", "newworld")
assert.Nil(t, err)
assert.Equal(t, "world", val)
val, err = client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "newworld", val)
ret, err := client.Del("hello")
assert.Nil(t, err)
assert.Equal(t, 1, ret)
})
}
func TestRedis_SetGetDel(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := New(client.Addr, badType()).Set("hello", "world")

View File

@@ -2,6 +2,7 @@ package redis
import (
"math/rand"
"strconv"
"sync/atomic"
"time"
@@ -11,19 +12,26 @@ import (
)
const (
randomLen = 16
tolerance = 500 // milliseconds
millisPerSecond = 1000
lockCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then
redis.call("SET", KEYS[1], ARGV[1], "PX", ARGV[2])
return "OK"
else
return redis.call("SET", KEYS[1], ARGV[1], "NX", "PX", ARGV[2])
end`
delCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
else
return 0
end`
randomLen = 16
)
// A RedisLock is a redis lock.
type RedisLock struct {
store *Redis
seconds uint32
count int32
key string
id string
}
@@ -43,35 +51,30 @@ func NewRedisLock(store *Redis, key string) *RedisLock {
// Acquire acquires the lock.
func (rl *RedisLock) Acquire() (bool, error) {
newCount := atomic.AddInt32(&rl.count, 1)
if newCount > 1 {
seconds := atomic.LoadUint32(&rl.seconds)
resp, err := rl.store.Eval(lockCommand, []string{rl.key}, []string{
rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance),
})
if err == red.Nil {
return false, nil
} else if err != nil {
logx.Errorf("Error on acquiring lock for %s, %s", rl.key, err.Error())
return false, err
} else if resp == nil {
return false, nil
}
reply, ok := resp.(string)
if ok && reply == "OK" {
return true, nil
}
seconds := atomic.LoadUint32(&rl.seconds)
ok, err := rl.store.SetnxEx(rl.key, rl.id, int(seconds+1)) // +1s for tolerance
if err == red.Nil {
atomic.AddInt32(&rl.count, -1)
return false, nil
} else if err != nil {
atomic.AddInt32(&rl.count, -1)
logx.Errorf("Error on acquiring lock for %s, %s", rl.key, err.Error())
return false, err
} else if !ok {
atomic.AddInt32(&rl.count, -1)
return false, nil
}
return true, nil
logx.Errorf("Unknown reply when acquiring lock for %s: %v", rl.key, resp)
return false, nil
}
// Release releases the lock.
func (rl *RedisLock) Release() (bool, error) {
newCount := atomic.AddInt32(&rl.count, -1)
if newCount > 0 {
return true, nil
}
resp, err := rl.store.Eval(delCommand, []string{rl.key}, []string{rl.id})
if err != nil {
return false, err

View File

@@ -29,25 +29,5 @@ func TestRedisLock(t *testing.T) {
endAcquire, err := secondLock.Acquire()
assert.Nil(t, err)
assert.True(t, endAcquire)
endAcquire, err = secondLock.Acquire()
assert.Nil(t, err)
assert.True(t, endAcquire)
release, err = secondLock.Release()
assert.Nil(t, err)
assert.True(t, release)
againAcquire, err = firstLock.Acquire()
assert.Nil(t, err)
assert.False(t, againAcquire)
release, err = secondLock.Release()
assert.Nil(t, err)
assert.True(t, release)
firstAcquire, err = firstLock.Acquire()
assert.Nil(t, err)
assert.True(t, firstAcquire)
})
}

View File

@@ -75,6 +75,7 @@ func format(query string, args ...interface{}) (string, error) {
break
}
}
if j > i+1 {
index, err := strconv.Atoi(query[i+1 : j])
if err != nil {
@@ -85,7 +86,7 @@ func format(query string, args ...interface{}) (string, error) {
if index > argIndex {
argIndex = index
}
index--
if index < 0 || numArgs <= index {
return "", fmt.Errorf("error: wrong index %d in sql", index)

View File

@@ -94,7 +94,7 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
handler.RecoverHandler,
handler.MetricHandler(metrics),
handler.MaxBytesHandler(ng.conf.MaxBytes),
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
handler.GunzipHandler,
)
chain = ng.appendAuthHandler(fr, chain, verifier)
@@ -119,6 +119,14 @@ func (ng *engine) bindRoutes(router httpx.Router) error {
return nil
}
func (ng *engine) checkedMaxBytes(bytes int64) int64 {
if bytes > 0 {
return bytes
}
return ng.conf.MaxBytes
}
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
if timeout > 0 {
return timeout

View File

@@ -194,6 +194,41 @@ func TestEngine_checkedTimeout(t *testing.T) {
}
}
func TestEngine_checkedMaxBytes(t *testing.T) {
tests := []struct {
name string
maxBytes int64
expect int64
}{
{
name: "not set",
expect: 1000,
},
{
name: "less",
maxBytes: 500,
expect: 500,
},
{
name: "equal",
maxBytes: 1000,
expect: 1000,
},
{
name: "more",
maxBytes: 1500,
expect: 1500,
},
}
ng := newEngine(RestConf{
MaxBytes: 1000,
})
for _, test := range tests {
assert.Equal(t, test.expect, ng.checkedMaxBytes(test.maxBytes))
}
}
func TestEngine_notFoundHandler(t *testing.T) {
logx.Disable()

View File

@@ -4,5 +4,5 @@ import "net/http"
type (
Interceptor func(r *http.Request) (*http.Request, ResponseHandler)
ResponseHandler func(*http.Response)
ResponseHandler func(resp *http.Response, err error)
)

View File

@@ -10,15 +10,21 @@ import (
func LogInterceptor(r *http.Request) (*http.Request, ResponseHandler) {
start := timex.Now()
return r, func(resp *http.Response) {
return r, func(resp *http.Response, err error) {
duration := timex.Since(start)
if err != nil {
logger := logx.WithContext(r.Context()).WithDuration(duration)
logger.Errorf("[HTTP] %s %s - %v", r.Method, r.URL, err)
return
}
var tc propagation.TraceContext
ctx := tc.Extract(r.Context(), propagation.HeaderCarrier(resp.Header))
logger := logx.WithContext(ctx).WithDuration(duration)
if isOkResponse(resp.StatusCode) {
logger.Infof("[HTTP] %d - %s %s/%s", resp.StatusCode, r.Method, r.Host, r.RequestURI)
logger.Infof("[HTTP] %d - %s %s", resp.StatusCode, r.Method, r.URL)
} else {
logger.Errorf("[HTTP] %d - %s %s/%s", resp.StatusCode, r.Method, r.Host, r.RequestURI)
logger.Errorf("[HTTP] %d - %s %s", resp.StatusCode, r.Method, r.URL)
}
}
}

View File

@@ -11,12 +11,13 @@ import (
func TestLogInterceptor(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
assert.Nil(t, err)
req, handler := LogInterceptor(req)
resp, err := http.DefaultClient.Do(req)
handler(resp, err)
assert.Nil(t, err)
handler(resp)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
@@ -24,11 +25,27 @@ func TestLogInterceptorServerError(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
assert.Nil(t, err)
req, handler := LogInterceptor(req)
resp, err := http.DefaultClient.Do(req)
handler(resp, err)
assert.Nil(t, err)
handler(resp)
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
}
func TestLogInterceptorServerClosed(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
assert.Nil(t, err)
svr.Close()
req, handler := LogInterceptor(req)
resp, err := http.DefaultClient.Do(req)
handler(resp, err)
assert.NotNil(t, err)
assert.Nil(t, resp)
}

View File

@@ -1,21 +1,44 @@
package httpc
import (
"io"
"net/http"
"github.com/zeromicro/go-zero/rest/httpc/internal"
)
// Do sends an HTTP request to the service assocated with the given key.
func Do(key string, r *http.Request) (*http.Response, error) {
return NewService(key).Do(r)
var interceptors = []internal.Interceptor{
internal.LogInterceptor,
}
// Get sends an HTTP GET request to the service assocated with the given key.
func Get(key, url string) (*http.Response, error) {
return NewService(key).Get(url)
// DoRequest sends an HTTP request and returns an HTTP response.
func DoRequest(r *http.Request) (*http.Response, error) {
return request(r, defaultClient{})
}
// Post sends an HTTP POST request to the service assocated with the given key.
func Post(key, url, contentType string, body io.Reader) (*http.Response, error) {
return NewService(key).Post(url, contentType, body)
type (
client interface {
do(r *http.Request) (*http.Response, error)
}
defaultClient struct{}
)
func (c defaultClient) do(r *http.Request) (*http.Response, error) {
return http.DefaultClient.Do(r)
}
func request(r *http.Request, cli client) (*http.Response, error) {
var respHandlers []internal.ResponseHandler
for _, interceptor := range interceptors {
var h internal.ResponseHandler
r, h = interceptor(r)
respHandlers = append(respHandlers, h)
}
resp, err := cli.do(r)
for i := len(respHandlers) - 1; i >= 0; i-- {
respHandlers[i](resp, err)
}
return resp, err
}

View File

@@ -11,27 +11,31 @@ import (
func TestDo(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
_, err := Get("foo", "tcp://bad request")
assert.NotNil(t, err)
resp, err := Get("foo", svr.URL)
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
assert.Nil(t, err)
resp, err := DoRequest(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestDoNotFound(t *testing.T) {
svr := httptest.NewServer(http.NotFoundHandler())
_, err := Post("foo", "tcp://bad request", "application/json", nil)
assert.NotNil(t, err)
resp, err := Post("foo", svr.URL, "application/json", nil)
defer svr.Close()
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
assert.Nil(t, err)
req.Header.Set("Content-Type", "application/json")
resp, err := DoRequest(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
}
func TestDoMoved(t *testing.T) {
svr := httptest.NewServer(http.RedirectHandler("/foo", http.StatusMovedPermanently))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
assert.Nil(t, err)
_, err = Do("foo", req)
_, err = DoRequest(req)
// too many redirects
assert.NotNil(t, err)
}

36
rest/httpc/responses.go Normal file
View File

@@ -0,0 +1,36 @@
package httpc
import (
"net/http"
"strings"
"github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/rest/internal/encoding"
)
// Parse parses the response.
func Parse(resp *http.Response, val interface{}) error {
if err := ParseHeaders(resp, val); err != nil {
return err
}
return ParseJsonBody(resp, val)
}
// ParseHeaders parses the rsponse headers.
func ParseHeaders(resp *http.Response, val interface{}) error {
return encoding.ParseHeaders(resp.Header, val)
}
// ParseJsonBody parses the rsponse body, which should be in json content type.
func ParseJsonBody(resp *http.Response, val interface{}) error {
if withJsonBody(resp) {
return mapping.UnmarshalJsonReader(resp.Body, val)
}
return mapping.UnmarshalJsonMap(nil, val)
}
func withJsonBody(r *http.Response) bool {
return r.ContentLength > 0 && strings.Contains(r.Header.Get(contentType), applicationJson)
}

View File

@@ -0,0 +1,64 @@
package httpc
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParse(t *testing.T) {
var val struct {
Foo string `header:"foo"`
Name string `json:"name"`
Value int `json:"value"`
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("foo", "bar")
w.Header().Set(contentType, applicationJson)
w.Write([]byte(`{"name":"kevin","value":100}`))
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
assert.Nil(t, err)
resp, err := DoRequest(req)
assert.Nil(t, err)
assert.Nil(t, Parse(resp, &val))
assert.Equal(t, "bar", val.Foo)
assert.Equal(t, "kevin", val.Name)
assert.Equal(t, 100, val.Value)
}
func TestParseHeaderError(t *testing.T) {
var val struct {
Foo int `header:"foo"`
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("foo", "bar")
w.Header().Set(contentType, applicationJson)
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
assert.Nil(t, err)
resp, err := DoRequest(req)
assert.Nil(t, err)
assert.NotNil(t, Parse(resp, &val))
}
func TestParseNoBody(t *testing.T) {
var val struct {
Foo string `header:"foo"`
}
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("foo", "bar")
w.Header().Set(contentType, applicationJson)
}))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
assert.Nil(t, err)
resp, err := DoRequest(req)
assert.Nil(t, err)
assert.Nil(t, Parse(resp, &val))
assert.Equal(t, "bar", val.Foo)
}

View File

@@ -1,33 +1,19 @@
package httpc
import (
"io"
"net/http"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest/httpc/internal"
)
// ContentType means Content-Type.
const ContentType = "Content-Type"
var interceptors = []internal.Interceptor{
internal.LogInterceptor,
}
type (
// Option is used to customize the *http.Client.
Option func(r *http.Request) *http.Request
// Service represents a remote HTTP service.
Service interface {
// Do sends an HTTP request to the service.
Do(r *http.Request) (*http.Response, error)
// Get sends an HTTP GET request to the service.
Get(url string) (*http.Response, error)
// Post sends an HTTP POST request to the service.
Post(url, contentType string, body io.Reader) (*http.Response, error)
// DoRequest sends a HTTP request to the service.
DoRequest(r *http.Request) (*http.Response, error)
}
namedService struct {
@@ -53,50 +39,12 @@ func NewServiceWithClient(name string, cli *http.Client, opts ...Option) Service
}
}
// Do sends an HTTP request to the service.
func (s namedService) Do(r *http.Request) (resp *http.Response, err error) {
var respHandlers []internal.ResponseHandler
for _, interceptor := range interceptors {
var h internal.ResponseHandler
r, h = interceptor(r)
respHandlers = append(respHandlers, h)
}
resp, err = s.doRequest(r)
if err != nil {
logx.Errorf("[HTTP] %s %s/%s - %v", r.Method, r.Host, r.RequestURI, err)
return
}
for i := len(respHandlers) - 1; i >= 0; i-- {
respHandlers[i](resp)
}
return
// DoRequest sends an HTTP request to the service.
func (s namedService) DoRequest(r *http.Request) (*http.Response, error) {
return request(r, s)
}
// Get sends an HTTP GET request to the service.
func (s namedService) Get(url string) (*http.Response, error) {
r, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
return s.Do(r)
}
// Post sends an HTTP POST request to the service.
func (s namedService) Post(url, contentType string, body io.Reader) (*http.Response, error) {
r, err := http.NewRequest(http.MethodPost, url, body)
if err != nil {
return nil, err
}
r.Header.Set(ContentType, contentType)
return s.Do(r)
}
func (s namedService) doRequest(r *http.Request) (resp *http.Response, err error) {
func (s namedService) do(r *http.Request) (resp *http.Response, err error) {
for _, opt := range s.opts {
r = opt(r)
}

View File

@@ -10,10 +10,11 @@ import (
func TestNamedService_Do(t *testing.T) {
svr := httptest.NewServer(http.RedirectHandler("/foo", http.StatusMovedPermanently))
defer svr.Close()
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
assert.Nil(t, err)
service := NewService("foo")
_, err = service.Do(req)
_, err = service.DoRequest(req)
// too many redirects
assert.NotNil(t, err)
}
@@ -22,11 +23,14 @@ func TestNamedService_Get(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("foo", r.Header.Get("foo"))
}))
defer svr.Close()
service := NewService("foo", func(r *http.Request) *http.Request {
r.Header.Set("foo", "bar")
return r
})
resp, err := service.Get(svr.URL)
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
assert.Nil(t, err)
resp, err := service.DoRequest(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "bar", resp.Header.Get("foo"))
@@ -34,10 +38,12 @@ func TestNamedService_Get(t *testing.T) {
func TestNamedService_Post(t *testing.T) {
svr := httptest.NewServer(http.NotFoundHandler())
defer svr.Close()
service := NewService("foo")
_, err := service.Post("tcp://bad request", "application/json", nil)
assert.NotNil(t, err)
resp, err := service.Post(svr.URL, "application/json", nil)
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
assert.Nil(t, err)
req.Header.Set("Content-Type", "application/json")
resp, err := service.DoRequest(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
}

6
rest/httpc/vars.go Normal file
View File

@@ -0,0 +1,6 @@
package httpc
const (
contentType = "Content-Type"
applicationJson = "application/json"
)

View File

@@ -3,17 +3,16 @@ package httpx
import (
"io"
"net/http"
"net/textproto"
"strings"
"github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/rest/internal/encoding"
"github.com/zeromicro/go-zero/rest/pathvar"
)
const (
formKey = "form"
pathKey = "path"
headerKey = "header"
maxMemory = 32 << 20 // 32MB
maxBodyLen = 8 << 20 // 8MB
separator = ";"
@@ -21,10 +20,8 @@ const (
)
var (
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues(),
mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey))
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
)
// Parse parses the request.
@@ -46,16 +43,7 @@ func Parse(r *http.Request, v interface{}) error {
// ParseHeaders parses the headers request.
func ParseHeaders(r *http.Request, v interface{}) error {
m := map[string]interface{}{}
for k, v := range r.Header {
if len(v) == 1 {
m[k] = v[0]
} else {
m[k] = v
}
}
return headerUnmarshaler.Unmarshal(m, v)
return encoding.ParseHeaders(r.Header, v)
}
// ParseForm parses the form request.

View File

@@ -0,0 +1,27 @@
package encoding
import (
"net/http"
"net/textproto"
"github.com/zeromicro/go-zero/core/mapping"
)
const headerKey = "header"
var headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues(),
mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey))
// ParseHeaders parses the headers request.
func ParseHeaders(header http.Header, v interface{}) error {
m := map[string]interface{}{}
for k, v := range header {
if len(v) == 1 {
m[k] = v[0]
} else {
m[k] = v
}
}
return headerUnmarshaler.Unmarshal(m, v)
}

View File

@@ -0,0 +1,40 @@
package encoding
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseHeaders(t *testing.T) {
var val struct {
Foo string `header:"foo"`
Baz int `header:"baz"`
Qux bool `header:"qux,default=true"`
}
r := httptest.NewRequest(http.MethodGet, "/any", nil)
r.Header.Set("foo", "bar")
r.Header.Set("baz", "1")
assert.Nil(t, ParseHeaders(r.Header, &val))
assert.Equal(t, "bar", val.Foo)
assert.Equal(t, 1, val.Baz)
assert.True(t, val.Qux)
}
func TestParseHeadersMulti(t *testing.T) {
var val struct {
Foo []string `header:"foo"`
Baz int `header:"baz"`
Qux bool `header:"qux,default=true"`
}
r := httptest.NewRequest(http.MethodGet, "/any", nil)
r.Header.Set("foo", "bar")
r.Header.Add("foo", "bar1")
r.Header.Set("baz", "1")
assert.Nil(t, ParseHeaders(r.Header, &val))
assert.Equal(t, []string{"bar", "bar1"}, val.Foo)
assert.Equal(t, 1, val.Baz)
assert.True(t, val.Qux)
}

View File

@@ -137,6 +137,13 @@ func WithJwtTransition(secret, prevSecret string) RouteOption {
}
}
// WithMaxBytes returns a RouteOption to set maxBytes with the given value.
func WithMaxBytes(maxBytes int64) RouteOption {
return func(r *featuredRoutes) {
r.maxBytes = maxBytes
}
}
// WithMiddlewares adds given middlewares to given routes.
func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
for i := len(ms) - 1; i >= 0; i-- {

View File

@@ -95,6 +95,13 @@ Port: 54321
}
}
func TestWithMaxBytes(t *testing.T) {
const maxBytes = 1000
var fr featuredRoutes
WithMaxBytes(maxBytes)(&fr)
assert.Equal(t, int64(maxBytes), fr.maxBytes)
}
func TestWithMiddleware(t *testing.T) {
m := make(map[string]string)
rt := router.NewRouter()

View File

@@ -36,5 +36,6 @@ type (
jwt jwtSetting
signature signatureSetting
routes []Route
maxBytes int64
}
)

34
tools/goctl/Dockerfile Normal file
View File

@@ -0,0 +1,34 @@
FROM golang:alpine AS builder
LABEL stage=gobuilder
ENV CGO_ENABLED 0
ENV GOPROXY https://goproxy.cn,direct
RUN apk update --no-cache && apk add --no-cache tzdata
RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@latest
RUN go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@latest
WORKDIR /build
ADD go.mod .
ADD go.sum .
RUN go mod download
COPY . .
RUN go build -ldflags="-s -w" -o /app/goctl ./goctl.go
FROM alpine
RUN apk update --no-cache && apk add --no-cache protoc
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
COPY --from=builder /usr/share/zoneinfo/Asia/Shanghai /usr/share/zoneinfo/Asia/Shanghai
COPY --from=builder /go/bin/protoc-gen-go /usr/bin/protoc-gen-go
COPY --from=builder /go/bin/protoc-gen-go-grpc /usr/bin/protoc-gen-go-grpc
ENV TZ Asia/Shanghai
WORKDIR /app
COPY --from=builder /app/goctl /app/goctl
CMD ["./goctl"]

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"go/format"
"go/scanner"
"io"
"io/ioutil"
"os"
"path/filepath"
@@ -29,14 +30,14 @@ const (
func GoFormatApi(c *cli.Context) error {
useStdin := c.Bool("stdin")
skipCheckDeclare := c.Bool("declare")
dir := c.String("dir")
var be errorx.BatchError
if useStdin {
if err := apiFormatByStdin(skipCheckDeclare); err != nil {
if err := apiFormatReader(os.Stdin, dir, skipCheckDeclare); err != nil {
be.Add(err)
}
} else {
dir := c.String("dir")
if len(dir) == 0 {
return errors.New("missing -dir")
}
@@ -65,13 +66,14 @@ func GoFormatApi(c *cli.Context) error {
return be.Err()
}
func apiFormatByStdin(skipCheckDeclare bool) error {
data, err := ioutil.ReadAll(os.Stdin)
// apiFormatReader
// filename is needed when there are `import` literals.
func apiFormatReader(reader io.Reader, filename string, skipCheckDeclare bool) error {
data, err := ioutil.ReadAll(reader)
if err != nil {
return err
}
result, err := apiFormat(string(data), skipCheckDeclare)
result, err := apiFormat(string(data), skipCheckDeclare, filename)
if err != nil {
return err
}

View File

@@ -1,15 +1,21 @@
package format
import (
"fmt"
"io/fs"
"io/ioutil"
"os"
"path"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
notFormattedStr = `
type Request struct {
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
}
type Response struct {
Message string ` + "`" + `json:"message"` + "`" + `
@@ -45,3 +51,26 @@ func TestFormat(t *testing.T) {
_, err = apiFormat(notFormattedStr, false)
assert.Errorf(t, err, " line 7:13 can not found declaration 'Student' in context")
}
func Test_apiFormatReader_issue1721(t *testing.T) {
dir, err := os.MkdirTemp("", "goctl-api-format")
require.NoError(t, err)
defer os.RemoveAll(dir)
subDir := path.Join(dir, "sub")
err = os.MkdirAll(subDir, fs.ModePerm)
require.NoError(t, err)
importedFilename := path.Join(dir, "foo.api")
err = ioutil.WriteFile(importedFilename, []byte{}, fs.ModePerm)
require.NoError(t, err)
filename := path.Join(subDir, "bar.api")
err = ioutil.WriteFile(filename, []byte(fmt.Sprintf(`import "%s"`, importedFilename)), 0644)
require.NoError(t, err)
f, err := os.Open(filename)
require.NoError(t, err)
err = apiFormatReader(f, filename, false)
assert.NoError(t, err)
}

View File

@@ -7,6 +7,7 @@ import (
"sort"
"strings"
"text/template"
"time"
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
@@ -23,7 +24,8 @@ const (
package handler
import (
"net/http"
"net/http"{{if .hasTimeout}}
"time"{{end}}
{{.importPackages}}
)
@@ -34,9 +36,10 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
`
routesAdditionTemplate = `
server.AddRoutes(
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}}
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}}
)
`
timeoutThreshold = time.Millisecond
)
var mapping = map[string]string{
@@ -57,6 +60,7 @@ type (
jwtEnabled bool
signatureEnabled bool
authName string
timeout string
middlewares []string
prefix string
jwtTrans string
@@ -80,6 +84,7 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error
return err
}
var hasTimeout bool
gt := template.Must(template.New("groupTemplate").Parse(templateText))
for _, g := range groups {
var gbuilder strings.Builder
@@ -110,6 +115,22 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error
rest.WithPrefix("%s"),`, g.prefix)
}
var timeout string
if len(g.timeout) > 0 {
duration, err := time.ParseDuration(g.timeout)
if err != nil {
return err
}
// why we check this, maybe some users set value 1, it's 1ns, not 1s.
if duration < timeoutThreshold {
return fmt.Errorf("timeout should not less than 1ms, now %v", duration)
}
timeout = fmt.Sprintf("rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond)
hasTimeout = true
}
var routes string
if len(g.middlewares) > 0 {
gbuilder.WriteString("\n}...,")
@@ -130,6 +151,7 @@ rest.WithPrefix("%s"),`, g.prefix)
"jwt": jwt,
"signature": signature,
"prefix": prefix,
"timeout": timeout,
}); err != nil {
return err
}
@@ -139,8 +161,8 @@ rest.WithPrefix("%s"),`, g.prefix)
if err != nil {
return err
}
routeFilename = routeFilename + ".go"
routeFilename = routeFilename + ".go"
filename := path.Join(dir, handlerDir, routeFilename)
os.Remove(filename)
@@ -152,7 +174,8 @@ rest.WithPrefix("%s"),`, g.prefix)
category: category,
templateFile: routesTemplateFile,
builtinTemplate: routesTemplate,
data: map[string]string{
data: map[string]interface{}{
"hasTimeout": hasTimeout,
"importPackages": genRouteImports(rootPkg, api),
"routesAdditions": strings.TrimSpace(builder.String()),
},
@@ -171,7 +194,8 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
continue
}
}
importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder), pathx.JoinPackages(parentPkg, handlerDir, folder)))
importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder),
pathx.JoinPackages(parentPkg, handlerDir, folder)))
}
}
imports := importSet.KeysStr()
@@ -205,6 +229,8 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
})
}
groupedRoutes.timeout = g.GetAnnotation("timeout")
jwt := g.GetAnnotation("jwt")
if len(jwt) > 0 {
groupedRoutes.authName = jwt

View File

@@ -21,7 +21,6 @@ func TestImportRegex(t *testing.T) {
{`"../foo/bar.api"`, true},
{`"../../foo/bar.api"`, true},
{`"bar..api"`, false},
{`"//bar.api"`, false},
{`"/foo/foo_bar.api"`, true},
}

View File

@@ -48,17 +48,15 @@ func Completion(c *cli.Context) error {
flag := magic
err = ioutil.WriteFile(zshF, zsh, os.ModePerm)
if err != nil {
return err
if err == nil {
flag |= flagZsh
}
flag |= flagZsh
err = ioutil.WriteFile(bashF, bash, os.ModePerm)
if err != nil {
return err
if err == nil {
flag |= flagBash
}
flag |= flagBash
buffer.WriteString(aurora.BrightGreen("generation auto completion success!\n").String())
buffer.WriteString(aurora.BrightGreen("executes the following script to setting shell:\n").String())
switch flag {

View File

@@ -28,7 +28,7 @@ type Docker struct {
GoRelPath string
GoFile string
ExeFile string
Scratch bool
BaseImage string
HasPort bool
Port int
Argument string
@@ -74,10 +74,10 @@ func DockerCommand(c *cli.Context) (err error) {
return fmt.Errorf("file %q not found", goFile)
}
scratch := c.Bool("scratch")
base := c.String("base")
port := c.Int("port")
if _, err := os.Stat(etcDir); os.IsNotExist(err) {
return generateDockerfile(goFile, scratch, port, version, timezone)
return generateDockerfile(goFile, base, port, version, timezone)
}
cfg, err := findConfig(goFile, etcDir)
@@ -85,7 +85,7 @@ func DockerCommand(c *cli.Context) (err error) {
return err
}
if err := generateDockerfile(goFile, scratch, port, version, timezone, "-f", "etc/"+cfg); err != nil {
if err := generateDockerfile(goFile, base, port, version, timezone, "-f", "etc/"+cfg); err != nil {
return err
}
@@ -126,7 +126,7 @@ func findConfig(file, dir string) (string, error) {
return files[0], nil
}
func generateDockerfile(goFile string, scratch bool, port int, version, timezone string, args ...string) error {
func generateDockerfile(goFile, base string, port int, version, timezone string, args ...string) error {
projPath, err := getFilePath(filepath.Dir(goFile))
if err != nil {
return err
@@ -159,7 +159,7 @@ func generateDockerfile(goFile string, scratch bool, port int, version, timezone
GoRelPath: projPath,
GoFile: goFile,
ExeFile: pathx.FileNameWithoutExt(filepath.Base(goFile)),
Scratch: scratch,
BaseImage: base,
HasPort: port > 0,
Port: port,
Argument: builder.String(),

View File

@@ -27,9 +27,9 @@ COPY . .
{{end}}RUN go build -ldflags="-s -w" -o /app/{{.ExeFile}} {{.GoRelPath}}/{{.GoFile}}
FROM {{if .Scratch}}scratch{{else}}alpine{{end}}
FROM {{.BaseImage}}
{{if .Scratch}}COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt{{else}}RUN apk update --no-cache && apk add --no-cache ca-certificates{{end}}
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
{{if .HasTimezone}}COPY --from=builder /usr/share/zoneinfo/{{.Timezone}} /usr/share/zoneinfo/{{.Timezone}}
ENV TZ {{.Timezone}}
{{end}}

View File

@@ -40,21 +40,23 @@ var bins = []bin{
func Check(ctx *cli.Context) error {
install := ctx.Bool("install")
force := ctx.Bool("force")
return Prepare(install, force)
verbose := ctx.Bool("verbose")
return Prepare(install, force, verbose)
}
func Prepare(install, force bool) error {
func Prepare(install, force, verbose bool) error {
log := console.NewColorConsole(verbose)
pending := true
console.Info("[goctl-env]: preparing to check env")
log.Info("[goctl-env]: preparing to check env")
defer func() {
if p := recover(); p != nil {
console.Error("%+v", p)
log.Error("%+v", p)
return
}
if pending {
console.Success("\n[goctl-env]: congratulations! your goctl environment is ready!")
log.Success("\n[goctl-env]: congratulations! your goctl environment is ready!")
} else {
console.Error(`
log.Error(`
[goctl-env]: check env finish, some dependencies is not found in PATH, you can execute
command 'goctl env check --install' to install it, for details, please execute command
'goctl env check --help'`)
@@ -62,29 +64,29 @@ command 'goctl env check --install' to install it, for details, please execute c
}()
for _, e := range bins {
time.Sleep(200 * time.Millisecond)
console.Info("")
console.Info("[goctl-env]: looking up %q", e.name)
log.Info("")
log.Info("[goctl-env]: looking up %q", e.name)
if e.exists {
console.Success("[goctl-env]: %q is installed", e.name)
log.Success("[goctl-env]: %q is installed", e.name)
continue
}
console.Warning("[goctl-env]: %q is not found in PATH", e.name)
log.Warning("[goctl-env]: %q is not found in PATH", e.name)
if install {
install := func() {
console.Info("[goctl-env]: preparing to install %q", e.name)
log.Info("[goctl-env]: preparing to install %q", e.name)
path, err := e.get(env.Get(env.GoctlCache))
if err != nil {
console.Error("[goctl-env]: an error interrupted the installation: %+v", err)
log.Error("[goctl-env]: an error interrupted the installation: %+v", err)
pending = false
} else {
console.Success("[goctl-env]: %q is already installed in %q", e.name, path)
log.Success("[goctl-env]: %q is already installed in %q", e.name, path)
}
}
if force {
install()
continue
}
console.Info("[goctl-env]: do you want to install %q [y: YES, n: No]", e.name)
log.Info("[goctl-env]: do you want to install %q [y: YES, n: No]", e.name)
for {
var in string
fmt.Scanln(&in)
@@ -95,10 +97,10 @@ command 'goctl env check --install' to install it, for details, please execute c
brk = true
case strings.EqualFold(in, "n"):
pending = false
console.Info("[goctl-env]: %q installation is ignored", e.name)
log.Info("[goctl-env]: %q installation is ignored", e.name)
brk = true
default:
console.Error("[goctl-env]: invalid input, input 'y' for yes, 'n' for no")
log.Error("[goctl-env]: invalid input, input 'y' for yes, 'n' for no")
}
if brk {
break

View File

@@ -23,7 +23,6 @@ import (
"github.com/zeromicro/go-zero/tools/goctl/completion"
"github.com/zeromicro/go-zero/tools/goctl/docker"
"github.com/zeromicro/go-zero/tools/goctl/env"
"github.com/zeromicro/go-zero/tools/goctl/internal/errorx"
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
"github.com/zeromicro/go-zero/tools/goctl/kube"
"github.com/zeromicro/go-zero/tools/goctl/migrate"
@@ -70,6 +69,10 @@ var commands = []cli.Command{
Name: "force, f",
Usage: "silent installation of non-existent dependencies",
},
cli.BoolFlag{
Name: "verbose, v",
Usage: "enable log output",
},
},
Action: env.Check,
},
@@ -345,9 +348,10 @@ var commands = []cli.Command{
Name: "go",
Usage: "the file that contains main function",
},
cli.BoolFlag{
Name: "scratch",
Usage: "use scratch for the base docker image",
cli.StringFlag{
Name: "base",
Usage: "the base image to build the docker image, default scratch",
Value: "scratch",
},
cli.IntFlag{
Name: "port",
@@ -492,9 +496,8 @@ var commands = []cli.Command{
Usage: "generate rpc code",
Subcommands: []cli.Command{
{
Name: "new",
Usage: `generate rpc demo service`,
Description: aurora.Yellow(`deprecated: zrpc code generation use "goctl rpc protoc" instead, for the details see "goctl rpc protoc --help"`).String(),
Name: "new",
Usage: `generate rpc demo service`,
Flags: []cli.Flag{
cli.StringFlag{
Name: "style",
@@ -519,6 +522,10 @@ var commands = []cli.Command{
Name: "branch",
Usage: "the branch of the remote repo, it does work with --remote",
},
cli.BoolFlag{
Name: "verbose, v",
Usage: "enable log output",
},
},
Action: rpc.RPCNew,
},
@@ -601,54 +608,11 @@ var commands = []cli.Command{
Name: "branch",
Usage: "the branch of the remote repo, it does work with --remote",
},
},
},
{
Name: "proto",
Usage: `generate rpc from proto`,
Description: aurora.Yellow(`deprecated: zrpc code generation use "goctl rpc protoc" instead, for the details see "goctl rpc protoc --help"`).String(),
Flags: []cli.Flag{
cli.StringFlag{
Name: "src, s",
Usage: "the file path of the proto source file",
},
cli.StringSliceFlag{
Name: "proto_path, I",
Usage: `native command of protoc, specify the directory in which to search for imports. [optional]`,
},
cli.StringSliceFlag{
Name: "go_opt",
Usage: `native command of protoc-gen-go, specify the mapping from proto to go, eg --go_opt=proto_import=go_package_import. [optional]`,
},
cli.StringFlag{
Name: "dir, d",
Usage: `the target path of the code`,
},
cli.StringFlag{
Name: "style",
Usage: "the file naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]",
},
cli.BoolFlag{
Name: "idea",
Usage: "whether the command execution environment is from idea plugin. [optional]",
},
cli.StringFlag{
Name: "home",
Usage: "the goctl home path of the template, --home and --remote cannot be set at the same time, " +
"if they are, --remote has higher priority",
},
cli.StringFlag{
Name: "remote",
Usage: "the remote git repo of the template, --home and --remote cannot be set at the same time, " +
"if they are, --remote has higher priority\n\tThe git repo directory must be consistent with the " +
"https://github.com/zeromicro/go-zero-template directory structure",
},
cli.StringFlag{
Name: "branch",
Usage: "the branch of the remote repo, it does work with --remote",
Name: "verbose, v",
Usage: "enable log output",
},
},
Action: rpc.RPC,
},
},
},
@@ -941,7 +905,7 @@ func main() {
// cli already print error messages.
if err := app.Run(os.Args); err != nil {
fmt.Println(aurora.Red(errorx.Wrap(err).Error()))
fmt.Println(aurora.Red(err.Error()))
os.Exit(codeFailure)
}
}

View File

@@ -6,7 +6,7 @@ import (
)
// BuildVersion is the version of goctl.
const BuildVersion = "1.3.3"
const BuildVersion = "1.3.4"
var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-bata": 2, "beta": 3, "released": 4, "": 5}

View File

@@ -20,16 +20,13 @@ import (
"github.com/zeromicro/go-zero/tools/goctl/util/stringx"
)
const (
pwd = "."
createTableFlag = `(?m)^(?i)CREATE\s+TABLE` // ignore case
)
const pwd = "."
type (
defaultGenerator struct {
// source string
dir string
console.Console
// source string
dir string
pkg string
cfg *config.Config
isPostgreSql bool
@@ -48,6 +45,12 @@ type (
updateCode string
deleteCode string
cacheExtra string
tableName string
}
codeTuple struct {
modelCode string
modelCustomCode string
}
)
@@ -109,7 +112,7 @@ func (g *defaultGenerator) StartFromDDL(filename string, withCache bool, databas
}
func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.Table, withCache bool) error {
m := make(map[string]string)
m := make(map[string]*codeTuple)
for _, each := range tables {
table, err := parser.ConvertDataType(each)
if err != nil {
@@ -120,14 +123,21 @@ func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.T
if err != nil {
return err
}
customCode, err := g.genModelCustom(*table, withCache)
if err != nil {
return err
}
m[table.Name.Source()] = code
m[table.Name.Source()] = &codeTuple{
modelCode: code,
modelCustomCode: customCode,
}
}
return g.createFile(m)
}
func (g *defaultGenerator) createFile(modelList map[string]string) error {
func (g *defaultGenerator) createFile(modelList map[string]*codeTuple) error {
dirAbs, err := filepath.Abs(g.dir)
if err != nil {
return err
@@ -140,20 +150,28 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error {
return err
}
for tableName, code := range modelList {
for tableName, codes := range modelList {
tn := stringx.From(tableName)
modelFilename, err := format.FileNamingFormat(g.cfg.NamingFormat, fmt.Sprintf("%s_model", tn.Source()))
modelFilename, err := format.FileNamingFormat(g.cfg.NamingFormat,
fmt.Sprintf("%s_model", tn.Source()))
if err != nil {
return err
}
name := util.SafeString(modelFilename) + ".go"
name := util.SafeString(modelFilename) + "_gen.go"
filename := filepath.Join(dirAbs, name)
err = ioutil.WriteFile(filename, []byte(codes.modelCode), os.ModePerm)
if err != nil {
return err
}
name = util.SafeString(modelFilename) + ".go"
filename = filepath.Join(dirAbs, name)
if pathx.FileExists(filename) {
g.Warning("%s already exists, ignored.", name)
continue
}
err = ioutil.WriteFile(filename, []byte(code), os.ModePerm)
err = ioutil.WriteFile(filename, []byte(codes.modelCustomCode), os.ModePerm)
if err != nil {
return err
}
@@ -183,8 +201,9 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error {
}
// ret1: key-table name,value-code
func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) (map[string]string, error) {
m := make(map[string]string)
func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) (
map[string]*codeTuple, error) {
m := make(map[string]*codeTuple)
tables, err := parser.Parse(filename, database)
if err != nil {
return nil, err
@@ -195,8 +214,15 @@ func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database
if err != nil {
return nil, err
}
customCode, err := g.genModelCustom(*e, withCache)
if err != nil {
return nil, err
}
m[e.Name.Source()] = code
m[e.Name.Source()] = &codeTuple{
modelCode: code,
modelCustomCode: customCode,
}
}
return m, nil
@@ -223,7 +249,7 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
table.UniqueCacheKey = uniqueKey
table.ContainsUniqueCacheKey = len(uniqueKey) > 0
importsCode, err := genImports(withCache, in.ContainsTime(), table)
importsCode, err := genImports(table, withCache, in.ContainsTime())
if err != nil {
return "", err
}
@@ -261,7 +287,8 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
}
var list []string
list = append(list, insertCodeMethod, findOneCodeMethod, ret.findOneInterfaceMethod, updateCodeMethod, deleteCodeMethod)
list = append(list, insertCodeMethod, findOneCodeMethod, ret.findOneInterfaceMethod,
updateCodeMethod, deleteCodeMethod)
typesCode, err := genTypes(table, strings.Join(modelutil.TrimStringSlice(list), pathx.NL), withCache)
if err != nil {
return "", err
@@ -272,6 +299,11 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return "", err
}
tableName, err := genTableName(table)
if err != nil {
return "", err
}
code := &code{
importsCode: importsCode,
varsCode: varsCode,
@@ -282,6 +314,7 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
updateCode: updateCode,
deleteCode: deleteCode,
cacheExtra: ret.cacheExtra,
tableName: tableName,
}
output, err := g.executeModel(table, code)
@@ -292,8 +325,30 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
return output.String(), nil
}
func (g *defaultGenerator) genModelCustom(in parser.Table, withCache bool) (string, error) {
text, err := pathx.LoadTemplate(category, modelCustomTemplateFile, template.ModelCustom)
if err != nil {
return "", err
}
t := util.With("model-custom").
Parse(text).
GoFmt(true)
output, err := t.Execute(map[string]interface{}{
"pkg": g.pkg,
"withCache": withCache,
"upperStartCamelObject": in.Name.ToCamel(),
"lowerStartCamelObject": stringx.From(in.Name.ToCamel()).Untitle(),
})
if err != nil {
return "", err
}
return output.String(), nil
}
func (g *defaultGenerator) executeModel(table Table, code *code) (*bytes.Buffer, error) {
text, err := pathx.LoadTemplate(category, modelTemplateFile, template.Model)
text, err := pathx.LoadTemplate(category, modelGenTemplateFile, template.ModelGen)
if err != nil {
return nil, err
}
@@ -311,6 +366,7 @@ func (g *defaultGenerator) executeModel(table Table, code *code) (*bytes.Buffer,
"update": code.updateCode,
"delete": code.deleteCode,
"extraMethod": code.cacheExtra,
"tableName": code.tableName,
"data": table,
})
if err != nil {

View File

@@ -4,16 +4,19 @@ import (
"database/sql"
"io/ioutil"
"os"
"path"
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/builderx"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/parser"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
@@ -121,3 +124,31 @@ func TestFields(t *testing.T) {
assert.Equal(t, "`name`,`age`,`score`", studentRowsExpectAutoSet)
assert.Equal(t, "`name`=?,`age`=?,`score`=?", studentRowsWithPlaceHolder)
}
func Test_genPublicModel(t *testing.T) {
var err error
dir := pathx.MustTempDir()
modelDir := path.Join(dir, "model")
err = os.MkdirAll(modelDir, 0777)
require.NoError(t, err)
defer os.RemoveAll(dir)
modelFilename := filepath.Join(modelDir, "foo.sql")
err = ioutil.WriteFile(modelFilename, []byte(source), 0777)
require.NoError(t, err)
g, err := NewDefaultGenerator(modelDir, &config.Config{
NamingFormat: config.DefaultFormat,
})
require.NoError(t, err)
tables, err := parser.Parse(modelFilename, "")
require.Equal(t, 1, len(tables))
code, err := g.genModelCustom(*tables[0], false)
assert.NoError(t, err)
assert.True(t, strings.Contains(code, "package model"))
assert.True(t, strings.Contains(code, "TestUserModel interface {\n\t\ttestUserModel\n\t}\n"))
assert.True(t, strings.Contains(code, "customTestUserModel struct {\n\t\t*defaultTestUserModel\n\t}\n"))
assert.True(t, strings.Contains(code, "func NewTestUserModel(conn sqlx.SqlConn) TestUserModel {"))
}

View File

@@ -6,7 +6,7 @@ import (
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
func genImports(withCache, timeImport bool, table Table) (string, error) {
func genImports(table Table, withCache, timeImport bool) (string, error) {
if withCache {
text, err := pathx.LoadTemplate(category, importsTemplateFile, template.Imports)
if err != nil {

View File

@@ -55,7 +55,6 @@ func genInsert(table Table, withCache, postgreSql bool) (string, string, error)
Parse(text).
Execute(map[string]interface{}{
"withCache": withCache,
"containsIndexCache": table.ContainsUniqueCacheKey,
"upperStartCamelObject": camel,
"lowerStartCamelObject": stringx.From(camel).Untitle(),
"expression": strings.Join(expressions, ", "),

View File

@@ -0,0 +1,26 @@
package gen
import (
"github.com/zeromicro/go-zero/tools/goctl/model/sql/template"
"github.com/zeromicro/go-zero/tools/goctl/util"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
func genTableName(table Table) (string, error) {
text, err := pathx.LoadTemplate(category, tableNameTemplateFile, template.TableName)
if err != nil {
return "", err
}
output, err := util.With("tableName").
Parse(text).
Execute(map[string]interface{}{
"tableName": table.Name.Source(),
"upperStartCamelObject": table.Name.ToCamel(),
})
if err != nil {
return "", nil
}
return output.String(), nil
}

View File

@@ -22,8 +22,10 @@ const (
importsWithNoCacheTemplateFile = "import-no-cache.tpl"
insertTemplateFile = "insert.tpl"
insertTemplateMethodFile = "interface-insert.tpl"
modelTemplateFile = "model.tpl"
modelGenTemplateFile = "model-gen.tpl"
modelCustomTemplateFile = "model.tpl"
modelNewTemplateFile = "model-new.tpl"
tableNameTemplateFile = "table-name.tpl"
tagTemplateFile = "tag.tpl"
typesTemplateFile = "types.tpl"
updateTemplateFile = "update.tpl"
@@ -45,8 +47,10 @@ var templates = map[string]string{
importsWithNoCacheTemplateFile: template.ImportsNoCache,
insertTemplateFile: template.Insert,
insertTemplateMethodFile: template.InsertMethod,
modelTemplateFile: template.Model,
modelGenTemplateFile: template.ModelGen,
modelCustomTemplateFile: template.ModelCustom,
modelNewTemplateFile: template.New,
tableNameTemplateFile: template.TableName,
tagTemplateFile: template.Tag,
typesTemplateFile: template.Types,
updateTemplateFile: template.Update,
@@ -70,7 +74,7 @@ func GenTemplates(_ *cli.Context) error {
return pathx.InitTemplates(category, templates)
}
// RevertTemplate recovers the delete template files
// RevertTemplate reverts the deleted template files
func RevertTemplate(name string) error {
content, ok := templates[name]
if !ok {

View File

@@ -4,6 +4,7 @@ import (
"github.com/zeromicro/go-zero/tools/goctl/model/sql/template"
"github.com/zeromicro/go-zero/tools/goctl/util"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
"github.com/zeromicro/go-zero/tools/goctl/util/stringx"
)
func genTypes(table Table, methods string, withCache bool) (string, error) {
@@ -24,6 +25,7 @@ func genTypes(table Table, methods string, withCache bool) (string, error) {
"withCache": withCache,
"method": methods,
"upperStartCamelObject": table.Name.ToCamel(),
"lowerStartCamelObject": stringx.From(table.Name.ToCamel()).Untitle(),
"fields": fieldsString,
"data": table,
})

View File

@@ -1,7 +1,8 @@
package template
// Delete defines a delete template
var Delete = `
const (
// Delete defines a delete template
Delete = `
func (m *default{{.upperStartCamelObject}}Model) Delete(ctx context.Context, {{.lowerStartCamelPrimaryKey}} {{.dataType}}) error {
{{if .withCache}}{{if .containsIndexCache}}data, err:=m.FindOne(ctx, {{.lowerStartCamelPrimaryKey}})
if err!=nil{
@@ -18,5 +19,6 @@ func (m *default{{.upperStartCamelObject}}Model) Delete(ctx context.Context, {{.
}
`
// DeleteMethod defines a delete template for interface method
var DeleteMethod = `Delete(ctx context.Context, {{.lowerStartCamelPrimaryKey}} {{.dataType}}) error`
// DeleteMethod defines a delete template for interface method
DeleteMethod = `Delete(ctx context.Context, {{.lowerStartCamelPrimaryKey}} {{.dataType}}) error`
)

View File

@@ -1,7 +1,7 @@
package template
// Error defines an error template
var Error = `package {{.pkg}}
const Error = `package {{.pkg}}
import "github.com/zeromicro/go-zero/core/stores/sqlx"

View File

@@ -1,4 +1,4 @@
package template
// Field defines a filed template for types
var Field = `{{.name}} {{.type}} {{.tag}} {{if .hasComment}}// {{.comment}}{{end}}`
const Field = `{{.name}} {{.type}} {{.tag}} {{if .hasComment}}// {{.comment}}{{end}}`

View File

@@ -1,7 +1,8 @@
package template
// FindOne defines find row by id.
var FindOne = `
const (
// FindOne defines find row by id.
FindOne = `
func (m *default{{.upperStartCamelObject}}Model) FindOne(ctx context.Context, {{.lowerStartCamelPrimaryKey}} {{.dataType}}) (*{{.upperStartCamelObject}}, error) {
{{if .withCache}}{{.cacheKey}}
var resp {{.upperStartCamelObject}}
@@ -30,8 +31,8 @@ func (m *default{{.upperStartCamelObject}}Model) FindOne(ctx context.Context, {{
}
`
// FindOneByField defines find row by field.
var FindOneByField = `
// FindOneByField defines find row by field.
FindOneByField = `
func (m *default{{.upperStartCamelObject}}Model) FindOneBy{{.upperField}}(ctx context.Context, {{.in}}) (*{{.upperStartCamelObject}}, error) {
{{if .withCache}}{{.cacheKey}}
var resp {{.upperStartCamelObject}}
@@ -64,8 +65,8 @@ func (m *default{{.upperStartCamelObject}}Model) FindOneBy{{.upperField}}(ctx co
}{{end}}
`
// FindOneByFieldExtraMethod defines find row by field with extras.
var FindOneByFieldExtraMethod = `
// FindOneByFieldExtraMethod defines find row by field with extras.
FindOneByFieldExtraMethod = `
func (m *default{{.upperStartCamelObject}}Model) formatPrimary(primary interface{}) string {
return fmt.Sprintf("%s%v", {{.primaryKeyLeft}}, primary)
}
@@ -76,8 +77,9 @@ func (m *default{{.upperStartCamelObject}}Model) queryPrimary(ctx context.Contex
}
`
// FindOneMethod defines find row method.
var FindOneMethod = `FindOne(ctx context.Context, {{.lowerStartCamelPrimaryKey}} {{.dataType}}) (*{{.upperStartCamelObject}}, error)`
// FindOneMethod defines find row method.
FindOneMethod = `FindOne(ctx context.Context, {{.lowerStartCamelPrimaryKey}} {{.dataType}}) (*{{.upperStartCamelObject}}, error)`
// FindOneByFieldMethod defines find row by field method.
var FindOneByFieldMethod = `FindOneBy{{.upperField}}(ctx context.Context, {{.in}}) (*{{.upperStartCamelObject}}, error) `
// FindOneByFieldMethod defines find row by field method.
FindOneByFieldMethod = `FindOneBy{{.upperField}}(ctx context.Context, {{.in}}) (*{{.upperStartCamelObject}}, error) `
)

View File

@@ -1,6 +1,6 @@
package template
var (
const (
// Imports defines a import template for model in cache case
Imports = `import (
"context"

View File

@@ -1,19 +1,19 @@
package template
// Insert defines a template for insert code in model
var Insert = `
const (
// Insert defines a template for insert code in model
Insert = `
func (m *default{{.upperStartCamelObject}}Model) Insert(ctx context.Context, data *{{.upperStartCamelObject}}) (sql.Result,error) {
{{if .withCache}}{{if .containsIndexCache}}{{.keys}}
{{if .withCache}}{{.keys}}
ret, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (result sql.Result, err error) {
query := fmt.Sprintf("insert into %s (%s) values ({{.expression}})", m.table, {{.lowerStartCamelObject}}RowsExpectAutoSet)
return conn.ExecCtx(ctx, query, {{.expressionValues}})
}, {{.keyValues}}){{else}}query := fmt.Sprintf("insert into %s (%s) values ({{.expression}})", m.table, {{.lowerStartCamelObject}}RowsExpectAutoSet)
ret,err:=m.ExecNoCacheCtx(ctx, query, {{.expressionValues}})
{{end}}{{else}}query := fmt.Sprintf("insert into %s (%s) values ({{.expression}})", m.table, {{.lowerStartCamelObject}}RowsExpectAutoSet)
ret,err:=m.conn.ExecCtx(ctx, query, {{.expressionValues}}){{end}}
return ret,err
}
`
// InsertMethod defines an interface method template for insert code in model
var InsertMethod = `Insert(ctx context.Context, data *{{.upperStartCamelObject}}) (sql.Result,error)`
// InsertMethod defines an interface method template for insert code in model
InsertMethod = `Insert(ctx context.Context, data *{{.upperStartCamelObject}}) (sql.Result,error)`
)

View File

@@ -1,7 +1,47 @@
package template
// Model defines a template for model
var Model = `package {{.pkg}}
import (
"fmt"
"github.com/zeromicro/go-zero/tools/goctl/util"
)
// ModelCustom defines a template for extension
const ModelCustom = `package {{.pkg}}
{{if .withCache}}
import (
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/sqlx"
)
{{else}}
import "github.com/zeromicro/go-zero/core/stores/sqlx"
{{end}}
var _ {{.upperStartCamelObject}}Model = (*custom{{.upperStartCamelObject}}Model)(nil)
type (
// {{.upperStartCamelObject}}Model is an interface to be customized, add more methods here,
// and implement the added methods in custom{{.upperStartCamelObject}}Model.
{{.upperStartCamelObject}}Model interface {
{{.lowerStartCamelObject}}Model
}
custom{{.upperStartCamelObject}}Model struct {
*default{{.upperStartCamelObject}}Model
}
)
// New{{.upperStartCamelObject}}Model returns a model for the database table.
func New{{.upperStartCamelObject}}Model(conn sqlx.SqlConn{{if .withCache}}, c cache.CacheConf{{end}}) {{.upperStartCamelObject}}Model {
return &custom{{.upperStartCamelObject}}Model{
default{{.upperStartCamelObject}}Model: new{{.upperStartCamelObject}}Model(conn{{if .withCache}}, c{{end}}),
}
}
`
// ModelGen defines a template for model
var ModelGen = fmt.Sprintf(`%s
package {{.pkg}}
{{.imports}}
{{.vars}}
{{.types}}
@@ -11,4 +51,5 @@ var Model = `package {{.pkg}}
{{.update}}
{{.delete}}
{{.extraMethod}}
`
{{.tableName}}
`, util.DoNotEditHead)

View File

@@ -1,8 +1,8 @@
package template
// New defines an template for creating model instance
var New = `
func New{{.upperStartCamelObject}}Model(conn sqlx.SqlConn{{if .withCache}}, c cache.CacheConf{{end}}) {{.upperStartCamelObject}}Model {
// New defines the template for creating model instance.
const New = `
func new{{.upperStartCamelObject}}Model(conn sqlx.SqlConn{{if .withCache}}, c cache.CacheConf{{end}}) *default{{.upperStartCamelObject}}Model {
return &default{{.upperStartCamelObject}}Model{
{{if .withCache}}CachedConn: sqlc.NewConn(conn, c){{else}}conn:conn{{end}},
table: {{.table}},

View File

@@ -0,0 +1,8 @@
package template
// TableName defines a template that generate the tableName method.
const TableName = `
func (m *default{{.upperStartCamelObject}}Model) tableName() string {
return m.table
}
`

View File

@@ -1,4 +1,4 @@
package template
// Tag defines a tag template text
var Tag = "`db:\"{{.field}}\"`"
const Tag = "`db:\"{{.field}}\"`"

View File

@@ -1,9 +1,9 @@
package template
// Types defines a template for types in model
var Types = `
// Types defines a template for types in model.
const Types = `
type (
{{.upperStartCamelObject}}Model interface{
{{.lowerStartCamelObject}}Model interface{
{{.method}}
}

View File

@@ -1,7 +1,8 @@
package template
// Update defines a template for generating update codes
var Update = `
const (
// Update defines a template for generating update codes
Update = `
func (m *default{{.upperStartCamelObject}}Model) Update(ctx context.Context, data *{{.upperStartCamelObject}}) error {
{{if .withCache}}{{.keys}}
_, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (result sql.Result, err error) {
@@ -13,5 +14,6 @@ func (m *default{{.upperStartCamelObject}}Model) Update(ctx context.Context, dat
}
`
// UpdateMethod defines an interface method template for generating update codes
var UpdateMethod = `Update(ctx context.Context, data *{{.upperStartCamelObject}}) error`
// UpdateMethod defines an interface method template for generating update codes
UpdateMethod = `Update(ctx context.Context, data *{{.upperStartCamelObject}}) error`
)

View File

@@ -28,7 +28,7 @@ func New() *SortedMap {
}
}
func (m *SortedMap) SetExpression(expression string) (key interface{}, value interface{}, err error) {
func (m *SortedMap) SetExpression(expression string) (key, value interface{}, err error) {
idx := strings.Index(expression, "=")
if idx == -1 {
return "", "", ErrInvalidKVExpression
@@ -86,7 +86,7 @@ func (m *SortedMap) Get(key interface{}) (interface{}, bool) {
return e.Value.(KV)[1], true
}
func (m *SortedMap) GetOr(key interface{}, dft interface{}) interface{} {
func (m *SortedMap) GetOr(key, dft interface{}) interface{} {
e, ok := m.keys[key]
if !ok {
return dft

View File

@@ -6,7 +6,7 @@ import (
"os"
)
func Download(url string, filename string) error {
func Download(url, filename string) error {
resp, err := http.Get(url)
if err != nil {
return err

View File

@@ -88,7 +88,7 @@ func Get(key string) string {
return GetOr(key, "")
}
func GetOr(key string, def string) string {
func GetOr(key, def string) string {
return goctlEnv.GetStringOr(key, def)
}

View File

@@ -4,84 +4,16 @@ import (
"errors"
"fmt"
"path/filepath"
"runtime"
"github.com/urfave/cli"
"github.com/zeromicro/go-zero/tools/goctl/rpc/generator"
"github.com/zeromicro/go-zero/tools/goctl/util"
"github.com/zeromicro/go-zero/tools/goctl/util/console"
"github.com/zeromicro/go-zero/tools/goctl/util/env"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
// Deprecated: use ZRPC instead.
// RPC is to generate rpc service code from a proto file by specifying a proto file using flag src,
// you can specify a target folder for code generation, when the proto file has import, you can specify
// the import search directory through the proto_path command, for specific usage, please refer to protoc -h
func RPC(c *cli.Context) error {
console.Warning("deprecated: use %q instead, for the details see %q",
"goctl rpc protoc", "goctl rpc protoc --help")
if err := prepare(); err != nil {
return err
}
src := c.String("src")
out := c.String("dir")
style := c.String("style")
protoImportPath := c.StringSlice("proto_path")
goOptions := c.StringSlice("go_opt")
home := c.String("home")
remote := c.String("remote")
branch := c.String("branch")
if len(remote) > 0 {
repo, _ := util.CloneIntoGitHome(remote, branch)
if len(repo) > 0 {
home = repo
}
}
if len(home) > 0 {
pathx.RegisterGoctlHome(home)
}
if len(src) == 0 {
return errors.New("missing -src")
}
if len(out) == 0 {
return errors.New("missing -dir")
}
g, err := generator.NewDefaultRPCGenerator(style)
if err != nil {
return err
}
return g.Generate(src, out, protoImportPath, goOptions...)
}
func prepare() error {
if !env.CanExec() {
return fmt.Errorf("%s: can not start new processes using os.StartProcess or exec.Command", runtime.GOOS)
}
if _, err := env.LookUpGo(); err != nil {
return err
}
if _, err := env.LookUpProtoc(); err != nil {
return err
}
if _, err := env.LookUpProtocGenGo(); err != nil {
return err
}
return nil
}
// RPCNew is to generate rpc greet service, this greet service can speed
// up your understanding of the zrpc service structure
func RPCNew(c *cli.Context) error {
console.Warning("deprecated: it will be removed in the feature, zrpc code generation please use %q instead",
"goctl rpc protoc")
rpcname := c.Args().First()
ext := filepath.Ext(rpcname)
if len(ext) > 0 {
@@ -91,6 +23,7 @@ func RPCNew(c *cli.Context) error {
home := c.String("home")
remote := c.String("remote")
branch := c.String("branch")
verbose := c.Bool("verbose")
if len(remote) > 0 {
repo, _ := util.CloneIntoGitHome(remote, branch)
if len(repo) > 0 {
@@ -113,12 +46,15 @@ func RPCNew(c *cli.Context) error {
return err
}
g, err := generator.NewDefaultRPCGenerator(style)
if err != nil {
return err
}
return g.Generate(src, filepath.Dir(src), nil)
var ctx generator.ZRpcContext
ctx.Src = src
ctx.GoOutput = filepath.Dir(src)
ctx.GrpcOutput = filepath.Dir(src)
ctx.IsGooglePlugin = true
ctx.Output = filepath.Dir(src)
ctx.ProtocCmd = fmt.Sprintf("protoc -I=%s %s --go_out=%s --go-grpc_out=%s", filepath.Dir(src), filepath.Base(src), filepath.Dir(src), filepath.Dir(src))
g := generator.NewGenerator(style, verbose)
return g.Generate(&ctx)
}
// RPCTemplate is the entry for generate rpc template

View File

@@ -42,6 +42,7 @@ func ZRPC(c *cli.Context) error {
home := c.String("home")
remote := c.String("remote")
branch := c.String("branch")
verbose := c.Bool("verbose")
if len(grpcOutList) == 0 {
return errInvalidGrpcOutput
}
@@ -107,12 +108,8 @@ func ZRPC(c *cli.Context) error {
ctx.IsGooglePlugin = isGooglePlugin
ctx.Output = zrpcOut
ctx.ProtocCmd = strings.Join(protocArgs, " ")
g, err := generator.NewDefaultRPCGenerator(style, generator.WithZRpcContext(&ctx))
if err != nil {
return err
}
return g.Generate(source, zrpcOut, nil)
g := generator.NewGenerator(style, verbose)
return g.Generate(&ctx)
}
func removeGoctlFlag(args []string) []string {
@@ -121,11 +118,18 @@ func removeGoctlFlag(args []string) []string {
for step < len(args) {
arg := args[step]
switch {
case arg == "--style", arg == "--home", arg == "--zrpc_out":
case arg == "--style", arg == "--home",
arg == "--zrpc_out", arg == "--verbose",
arg == "-v", arg == "--remote",
arg == "--branch":
step += 2
continue
case strings.HasPrefix(arg, "--style="),
strings.HasPrefix(arg, "--home="),
strings.HasPrefix(arg, "--verbose="),
strings.HasPrefix(arg, "-v="),
strings.HasPrefix(arg, "--remote="),
strings.HasPrefix(arg, "--branch="),
strings.HasPrefix(arg, "--zrpc_out="):
step += 1
continue

View File

@@ -83,6 +83,14 @@ func Test_RemoveGoctlFlag(t *testing.T) {
source: strings.Fields(`protoc --go_opt=. --go-grpc_out=. --zrpc_out=. foo.proto`),
expected: "protoc --go_opt=. --go-grpc_out=. foo.proto",
},
{
source: strings.Fields(`protoc --go_opt=. --go-grpc_out=. --zrpc_out=. --remote=foo --branch=bar foo.proto`),
expected: "protoc --go_opt=. --go-grpc_out=. foo.proto",
},
{
source: strings.Fields(`protoc --go_opt=. --go-grpc_out=. --zrpc_out=. --remote foo --branch bar foo.proto`),
expected: "protoc --go_opt=. --go-grpc_out=. foo.proto",
},
}
for _, e := range testData {
cmd := strings.Join(removeGoctlFlag(e.source), " ")

View File

@@ -1,6 +1,7 @@
syntax = "proto3";
package common;
option go_package="./common";
message User {
string name = 1;

View File

@@ -1,28 +0,0 @@
package generator
import (
"github.com/zeromicro/go-zero/tools/goctl/env"
"github.com/zeromicro/go-zero/tools/goctl/util/console"
)
// DefaultGenerator defines the environment needs of rpc service generation
type DefaultGenerator struct {
log console.Console
}
// just test interface implement
var _ Generator = (*DefaultGenerator)(nil)
// NewDefaultGenerator returns an instance of DefaultGenerator
func NewDefaultGenerator() Generator {
log := console.NewColorConsole()
return &DefaultGenerator{
log: log,
}
}
// Prepare provides environment detection generated by rpc service,
// including go environment, protoc, whether protoc-gen-go is installed or not
func (g *DefaultGenerator) Prepare() error {
return env.Prepare(true, true)
}

View File

@@ -3,22 +3,12 @@ package generator
import (
"path/filepath"
conf "github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/rpc/parser"
"github.com/zeromicro/go-zero/tools/goctl/util/console"
"github.com/zeromicro/go-zero/tools/goctl/util/ctx"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
// RPCGenerator defines a generator and configure
type RPCGenerator struct {
g Generator
cfg *conf.Config
ctx *ZRpcContext
}
type RPCGeneratorOption func(g *RPCGenerator)
type ZRpcContext struct {
Src string
ProtocCmd string
@@ -30,38 +20,11 @@ type ZRpcContext struct {
Output string
}
// NewDefaultRPCGenerator wraps Generator with configure
func NewDefaultRPCGenerator(style string, options ...RPCGeneratorOption) (*RPCGenerator, error) {
cfg, err := conf.NewConfig(style)
if err != nil {
return nil, err
}
return NewRPCGenerator(NewDefaultGenerator(), cfg, options...), nil
}
// NewRPCGenerator creates an instance for RPCGenerator
func NewRPCGenerator(g Generator, cfg *conf.Config, options ...RPCGeneratorOption) *RPCGenerator {
out := &RPCGenerator{
g: g,
cfg: cfg,
}
for _, opt := range options {
opt(out)
}
return out
}
func WithZRpcContext(c *ZRpcContext) RPCGeneratorOption {
return func(g *RPCGenerator) {
g.ctx = c
}
}
// Generate generates an rpc service, through the proto file,
// code storage directory, and proto import parameters to control
// the source file and target location of the rpc service that needs to be generated
func (g *RPCGenerator) Generate(src, target string, protoImportPath []string, goOptions ...string) error {
abs, err := filepath.Abs(target)
func (g *Generator) Generate(zctx *ZRpcContext) error {
abs, err := filepath.Abs(zctx.Output)
if err != nil {
return err
}
@@ -71,7 +34,7 @@ func (g *RPCGenerator) Generate(src, target string, protoImportPath []string, go
return err
}
err = g.g.Prepare()
err = g.Prepare()
if err != nil {
return err
}
@@ -82,52 +45,52 @@ func (g *RPCGenerator) Generate(src, target string, protoImportPath []string, go
}
p := parser.NewDefaultProtoParser()
proto, err := p.Parse(src)
proto, err := p.Parse(zctx.Src)
if err != nil {
return err
}
dirCtx, err := mkdir(projectCtx, proto, g.cfg, g.ctx)
dirCtx, err := mkdir(projectCtx, proto, g.cfg, zctx)
if err != nil {
return err
}
err = g.g.GenEtc(dirCtx, proto, g.cfg)
err = g.GenEtc(dirCtx, proto, g.cfg)
if err != nil {
return err
}
err = g.g.GenPb(dirCtx, protoImportPath, proto, g.cfg, g.ctx, goOptions...)
err = g.GenPb(dirCtx, zctx)
if err != nil {
return err
}
err = g.g.GenConfig(dirCtx, proto, g.cfg)
err = g.GenConfig(dirCtx, proto, g.cfg)
if err != nil {
return err
}
err = g.g.GenSvc(dirCtx, proto, g.cfg)
err = g.GenSvc(dirCtx, proto, g.cfg)
if err != nil {
return err
}
err = g.g.GenLogic(dirCtx, proto, g.cfg)
err = g.GenLogic(dirCtx, proto, g.cfg)
if err != nil {
return err
}
err = g.g.GenServer(dirCtx, proto, g.cfg)
err = g.GenServer(dirCtx, proto, g.cfg)
if err != nil {
return err
}
err = g.g.GenMain(dirCtx, proto, g.cfg)
err = g.GenMain(dirCtx, proto, g.cfg)
if err != nil {
return err
}
err = g.g.GenCall(dirCtx, proto, g.cfg)
err = g.GenCall(dirCtx, proto, g.cfg)
console.NewColorConsole().MarkDone()

View File

@@ -1,6 +1,7 @@
package generator
import (
"fmt"
"go/build"
"os"
"path/filepath"
@@ -10,26 +11,19 @@ import (
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx"
conf "github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/rpc/execx"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
var cfg = &conf.Config{
NamingFormat: "gozero",
}
func TestRpcGenerate(t *testing.T) {
_ = Clean()
dispatcher := NewDefaultGenerator()
err := dispatcher.Prepare()
g := NewGenerator("gozero", true)
err := g.Prepare()
if err != nil {
logx.Error(err)
return
}
projectName := stringx.Rand()
g := NewRPCGenerator(dispatcher, cfg)
src := filepath.Join(build.Default.GOPATH, "src")
_, err = os.Stat(src)
if err != nil {
@@ -46,7 +40,15 @@ func TestRpcGenerate(t *testing.T) {
// case go path
t.Run("GOPATH", func(t *testing.T) {
err = g.Generate("./test.proto", projectDir, []string{common}, "Mbase/common.proto=./base")
ctx := &ZRpcContext{
Src: "./test.proto",
ProtocCmd: fmt.Sprintf("protoc -I=%s test.proto --go_out=%s --go_opt=Mbase/common.proto=./base --go-grpc_out=%s", common, projectDir, projectDir),
IsGooglePlugin: true,
GoOutput: projectDir,
GrpcOutput: projectDir,
Output: projectDir,
}
err = g.Generate(ctx)
assert.Nil(t, err)
_, err = execx.Run("go test "+projectName, projectDir)
if err != nil {
@@ -67,7 +69,15 @@ func TestRpcGenerate(t *testing.T) {
}
projectDir = filepath.Join(workDir, projectName)
err = g.Generate("./test.proto", projectDir, []string{common}, "Mbase/common.proto=./base")
ctx := &ZRpcContext{
Src: "./test.proto",
ProtocCmd: fmt.Sprintf("protoc -I=%s test.proto --go_out=%s --go_opt=Mbase/common.proto=./base --go-grpc_out=%s", common, projectDir, projectDir),
IsGooglePlugin: true,
GoOutput: projectDir,
GrpcOutput: projectDir,
Output: projectDir,
}
err = g.Generate(ctx)
assert.Nil(t, err)
_, err = execx.Run("go test "+projectName, projectDir)
if err != nil {
@@ -79,7 +89,15 @@ func TestRpcGenerate(t *testing.T) {
// case not in go mod and go path
t.Run("OTHER", func(t *testing.T) {
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
ctx := &ZRpcContext{
Src: "./test.proto",
ProtocCmd: fmt.Sprintf("protoc -I=%s test.proto --go_out=%s --go_opt=Mbase/common.proto=./base --go-grpc_out=%s", common, projectDir, projectDir),
IsGooglePlugin: true,
GoOutput: projectDir,
GrpcOutput: projectDir,
Output: projectDir,
}
err = g.Generate(ctx)
assert.Nil(t, err)
_, err = execx.Run("go test "+projectName, projectDir)
if err != nil {

View File

@@ -58,7 +58,7 @@ func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
callFunctionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (m *default{{.serviceName}}) {{.method}}(ctx context.Context{{if .hasReq}}, in *{{.pbRequest}}{{end}}, opts ...grpc.CallOption) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error) {
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
client := {{if .isCallPkgSameToGrpcPkg}}{{else}}{{.package}}.{{end}}New{{.rpcServiceName}}Client(m.cli.Conn())
return client.{{.method}}(ctx{{if .hasReq}}, in{{end}}, opts...)
}
`
@@ -66,10 +66,12 @@ func (m *default{{.serviceName}}) {{.method}}(ctx context.Context{{if .hasReq}},
// GenCall generates the rpc client code, which is the entry point for the rpc service call.
// It is a layer of encapsulation for the rpc client and shields the details in the pb.
func (g *DefaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
func (g *Generator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
dir := ctx.GetCall()
service := proto.Service
head := util.GetHead(proto.Name)
isCallPkgSameToPbPkg := ctx.GetCall().Filename == ctx.GetPb().Filename
isCallPkgSameToGrpcPkg := ctx.GetCall().Filename == ctx.GetProtoGo().Filename
callFilename, err := format.FileNamingFormat(cfg.NamingFormat, service.Name)
if err != nil {
@@ -77,12 +79,12 @@ func (g *DefaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf
}
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", callFilename))
functions, err := g.genFunction(proto.PbPackage, service)
functions, err := g.genFunction(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
if err != nil {
return err
}
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service)
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
if err != nil {
return err
}
@@ -93,11 +95,19 @@ func (g *DefaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf
}
alias := collection.NewSet()
for _, item := range proto.Message {
msgName := getMessageName(*item.Message)
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
if !isCallPkgSameToPbPkg {
for _, item := range proto.Message {
msgName := getMessageName(*item.Message)
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName), fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
}
}
pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
if isCallPkgSameToGrpcPkg {
pbPackage = ""
protoGoPackage = ""
}
aliasKeys := alias.KeysStr()
sort.Strings(aliasKeys)
err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
@@ -105,8 +115,8 @@ func (g *DefaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf
"alias": strings.Join(aliasKeys, pathx.NL),
"head": head,
"filePackage": dir.Base,
"pbPackage": fmt.Sprintf(`"%s"`, ctx.GetPb().Package),
"protoGoPackage": fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package),
"pbPackage": pbPackage,
"protoGoPackage": protoGoPackage,
"serviceName": stringx.From(service.Name).ToCamel(),
"functions": strings.Join(functions, pathx.NL),
"interface": strings.Join(iFunctions, pathx.NL),
@@ -136,7 +146,7 @@ func getMessageName(msg proto.Message) string {
return strings.Join(list, "_")
}
func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service) ([]string, error) {
func (g *Generator) genFunction(goPackage string, service parser.Service, isCallPkgSameToGrpcPkg bool) ([]string, error) {
functions := make([]string, 0)
for _, rpc := range service.RPC {
@@ -147,18 +157,22 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
comment := parser.GetComment(rpc.Doc())
streamServer := fmt.Sprintf("%s.%s_%s%s", goPackage, parser.CamelCase(service.Name), parser.CamelCase(rpc.Name), "Client")
if isCallPkgSameToGrpcPkg {
streamServer = fmt.Sprintf("%s_%s%s", parser.CamelCase(service.Name), parser.CamelCase(rpc.Name), "Client")
}
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
"serviceName": stringx.From(service.Name).ToCamel(),
"rpcServiceName": parser.CamelCase(service.Name),
"method": parser.CamelCase(rpc.Name),
"package": goPackage,
"pbRequest": parser.CamelCase(rpc.RequestType),
"pbResponse": parser.CamelCase(rpc.ReturnsType),
"hasComment": len(comment) > 0,
"comment": comment,
"hasReq": !rpc.StreamsRequest,
"notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
"streamBody": streamServer,
"serviceName": stringx.From(service.Name).ToCamel(),
"rpcServiceName": parser.CamelCase(service.Name),
"method": parser.CamelCase(rpc.Name),
"package": goPackage,
"pbRequest": parser.CamelCase(rpc.RequestType),
"pbResponse": parser.CamelCase(rpc.ReturnsType),
"hasComment": len(comment) > 0,
"comment": comment,
"hasReq": !rpc.StreamsRequest,
"notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
"streamBody": streamServer,
"isCallPkgSameToGrpcPkg": isCallPkgSameToGrpcPkg,
})
if err != nil {
return nil, err
@@ -170,7 +184,7 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
return functions, nil
}
func (g *DefaultGenerator) getInterfaceFuncs(goPackage string, service parser.Service) ([]string, error) {
func (g *Generator) getInterfaceFuncs(goPackage string, service parser.Service, isCallPkgSameToGrpcPkg bool) ([]string, error) {
functions := make([]string, 0)
for _, rpc := range service.RPC {
@@ -181,6 +195,9 @@ func (g *DefaultGenerator) getInterfaceFuncs(goPackage string, service parser.Se
comment := parser.GetComment(rpc.Doc())
streamServer := fmt.Sprintf("%s.%s_%s%s", goPackage, parser.CamelCase(service.Name), parser.CamelCase(rpc.Name), "Client")
if isCallPkgSameToGrpcPkg {
streamServer = fmt.Sprintf("%s_%s%s", parser.CamelCase(service.Name), parser.CamelCase(rpc.Name), "Client")
}
buffer, err := util.With("interfaceFn").Parse(text).Execute(
map[string]interface{}{
"hasComment": len(comment) > 0,

View File

@@ -24,7 +24,7 @@ type Config struct {
// which contains the zrpc.RpcServerConf configuration item by default.
// You can specify the naming style of the target file name through config.Config. For details,
// see https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/config.go
func (g *DefaultGenerator) GenConfig(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
func (g *Generator) GenConfig(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
dir := ctx.GetConfig()
configFilename, err := format.FileNamingFormat(cfg.NamingFormat, "config")
if err != nil {

View File

@@ -1,19 +1,36 @@
package generator
import (
"log"
conf "github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/rpc/parser"
"github.com/zeromicro/go-zero/tools/goctl/env"
"github.com/zeromicro/go-zero/tools/goctl/util/console"
)
// Generator defines a generator interface to describe how to generate rpc service
type Generator interface {
Prepare() error
GenMain(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenCall(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenEtc(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenConfig(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenServer(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenSvc(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, cfg *conf.Config, c *ZRpcContext, goOptions ...string) error
// Generator defines the environment needs of rpc service generation
type Generator struct {
log console.Console
cfg *conf.Config
verbose bool
}
// NewGenerator returns an instance of Generator
func NewGenerator(style string, verbose bool) *Generator {
cfg, err := conf.NewConfig(style)
if err != nil {
log.Fatalln(err)
}
log := console.NewColorConsole(verbose)
return &Generator{
log: log,
cfg: cfg,
verbose: verbose,
}
}
// Prepare provides environment detection generated by rpc service,
// including go environment, protoc, whether protoc-gen-go is installed or not
func (g *Generator) Prepare() error {
return env.Prepare(true, true, g.verbose)
}

View File

@@ -23,7 +23,7 @@ Etcd:
// GenEtc generates the yaml configuration file of the rpc service,
// including host, port monitoring configuration items and etcd configuration
func (g *DefaultGenerator) GenEtc(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
func (g *Generator) GenEtc(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
dir := ctx.GetEtc()
etcFilename, err := format.FileNamingFormat(cfg.NamingFormat, ctx.GetServiceName().Source())
if err != nil {

View File

@@ -50,7 +50,7 @@ func (l *{{.logicName}}) {{.method}} ({{if .hasReq}}in {{.request}}{{if .stream}
)
// GenLogic generates the logic file of the rpc service, which corresponds to the RPC definition items in proto.
func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
func (g *Generator) GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
dir := ctx.GetLogic()
service := proto.Service.Service.Name
for _, rpc := range proto.Service.RPC {
@@ -84,7 +84,7 @@ func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *con
return nil
}
func (g *DefaultGenerator) genLogicFunction(serviceName, goPackage string, rpc *parser.RPC) (string, error) {
func (g *Generator) genLogicFunction(serviceName, goPackage string, rpc *parser.RPC) (string, error) {
functions := make([]string, 0)
text, err := pathx.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
if err != nil {

View File

@@ -53,7 +53,7 @@ func main() {
`
// GenMain generates the main file of the rpc service, which is an rpc service program call entry
func (g *DefaultGenerator) GenMain(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
func (g *Generator) GenMain(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
mainFilename, err := format.FileNamingFormat(cfg.NamingFormat, ctx.GetServiceName().Source())
if err != nil {
return err

View File

@@ -1,118 +1,22 @@
package generator
import (
"bytes"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"github.com/zeromicro/go-zero/core/collection"
conf "github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/rpc/execx"
"github.com/zeromicro/go-zero/tools/goctl/rpc/parser"
)
const googleProtocGenGoErr = `--go_out: protoc-gen-go: plugins are not supported; use 'protoc --go-grpc_out=...' to generate gRPC`
// GenPb generates the pb.go file, which is a layer of packaging for protoc to generate gprc,
// but the commands and flags in protoc are not completely joined in goctl. At present, proto_path(-I) is introduced
func (g *DefaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, _ *conf.Config, c *ZRpcContext, goOptions ...string) error {
if c != nil {
return g.genPbDirect(ctx, c)
}
// deprecated: use genPbDirect instead.
dir := ctx.GetPb()
cw := new(bytes.Buffer)
directory, base := filepath.Split(proto.Src)
directory = filepath.Clean(directory)
cw.WriteString("protoc ")
protoImportPathSet := collection.NewSet()
isSamePackage := true
for _, ip := range protoImportPath {
pip := " --proto_path=" + ip
if protoImportPathSet.Contains(pip) {
continue
}
abs, err := filepath.Abs(ip)
if err != nil {
return err
}
if abs == directory {
isSamePackage = true
} else {
isSamePackage = false
}
protoImportPathSet.AddStr(pip)
cw.WriteString(pip)
}
currentPath := " --proto_path=" + directory
if !protoImportPathSet.Contains(currentPath) {
cw.WriteString(currentPath)
}
cw.WriteString(" " + proto.Name)
if strings.Contains(proto.GoPackage, "/") {
cw.WriteString(" --go_out=plugins=grpc:" + ctx.GetMain().Filename)
} else {
cw.WriteString(" --go_out=plugins=grpc:" + dir.Filename)
}
// Compatible with version 1.4.0github.com/golang/protobuf/protoc-gen-go@v1.4.0
// --go_opt usage please see https://developers.google.com/protocol-buffers/docs/reference/go-generated#package
optSet := collection.NewSet()
for _, op := range goOptions {
opt := " --go_opt=" + op
if optSet.Contains(opt) {
continue
}
optSet.AddStr(op)
cw.WriteString(" --go_opt=" + op)
}
var currentFileOpt string
if !isSamePackage || (len(proto.GoPackage) > 0 && proto.GoPackage != proto.Package.Name) {
if filepath.IsAbs(proto.GoPackage) {
currentFileOpt = " --go_opt=M" + base + "=" + proto.GoPackage
} else if strings.Contains(proto.GoPackage, string(filepath.Separator)) {
currentFileOpt = " --go_opt=M" + base + "=./" + proto.GoPackage
} else {
currentFileOpt = " --go_opt=M" + base + "=../" + proto.GoPackage
}
} else {
currentFileOpt = " --go_opt=M" + base + "=."
}
if !optSet.Contains(currentFileOpt) {
cw.WriteString(currentFileOpt)
}
command := cw.String()
g.log.Debug(command)
_, err := execx.Run(command, "")
if err != nil {
if strings.Contains(err.Error(), googleProtocGenGoErr) {
return errors.New(`unsupported plugin protoc-gen-go which installed from the following source:
google.golang.org/protobuf/cmd/protoc-gen-go,
github.com/protocolbuffers/protobuf-go/cmd/protoc-gen-go;
Please replace it by the following command, we recommend to use version before v1.3.5:
go get -u github.com/golang/protobuf/protoc-gen-go`)
}
return err
}
return nil
func (g *Generator) GenPb(ctx DirContext, c *ZRpcContext) error {
return g.genPbDirect(ctx, c)
}
func (g *DefaultGenerator) genPbDirect(ctx DirContext, c *ZRpcContext) error {
func (g *Generator) genPbDirect(ctx DirContext, c *ZRpcContext) error {
g.log.Debug("[command]: %s", c.ProtocCmd)
pwd, err := os.Getwd()
if err != nil {
@@ -126,7 +30,7 @@ func (g *DefaultGenerator) genPbDirect(ctx DirContext, c *ZRpcContext) error {
return g.setPbDir(ctx, c)
}
func (g *DefaultGenerator) setPbDir(ctx DirContext, c *ZRpcContext) error {
func (g *Generator) setPbDir(ctx DirContext, c *ZRpcContext) error {
pbDir, err := findPbFile(c.GoOutput, false)
if err != nil {
return err

View File

@@ -25,7 +25,7 @@ message Resp{}
service Greeter {
rpc greet(Req) returns (Resp);
}
`), 0666)
`), 0o666)
if err != nil {
t.Log(err)
return

View File

@@ -48,7 +48,7 @@ func (s *{{.server}}Server) {{.method}} ({{if .notStream}}ctx context.Context,{{
)
// GenServer generates rpc server file, which is an implementation of rpc server
func (g *DefaultGenerator) GenServer(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
func (g *Generator) GenServer(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
dir := ctx.GetServer()
logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package)
svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
@@ -94,7 +94,7 @@ func (g *DefaultGenerator) GenServer(ctx DirContext, proto parser.Proto, cfg *co
return err
}
func (g *DefaultGenerator) genFunctions(goPackage string, service parser.Service) ([]string, error) {
func (g *Generator) genFunctions(goPackage string, service parser.Service) ([]string, error) {
var functionList []string
for _, rpc := range service.RPC {
text, err := pathx.LoadTemplate(category, serverFuncTemplateFile, functionTemplate)

View File

@@ -28,7 +28,7 @@ func NewServiceContext(c config.Config) *ServiceContext {
// GenSvc generates the servicecontext.go file, which is the resource dependency of a service,
// such as rpc dependency, model dependency, etc.
func (g *DefaultGenerator) GenSvc(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
func (g *Generator) GenSvc(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
dir := ctx.GetSvc()
svcFilename, err := format.FileNamingFormat(cfg.NamingFormat, "service_context")
if err != nil {

View File

@@ -12,6 +12,7 @@ import (
const rpcTemplateText = `syntax = "proto3";
package {{.package}};
option go_package="./{{.package}}";
message Request {
string ping = 1;

View File

@@ -24,7 +24,9 @@ type (
Must(err error)
}
colorConsole struct{}
colorConsole struct {
enable bool
}
// for idea log
ideaConsole struct{}
@@ -39,45 +41,75 @@ func NewConsole(idea bool) Console {
}
// NewColorConsole returns an instance of colorConsole
func NewColorConsole() Console {
return &colorConsole{}
func NewColorConsole(enable ...bool) Console {
logEnable := true
for _, e := range enable {
logEnable = e
}
return &colorConsole{
enable: logEnable,
}
}
func (c *colorConsole) Info(format string, a ...interface{}) {
if !c.enable {
return
}
msg := fmt.Sprintf(format, a...)
fmt.Println(msg)
}
func (c *colorConsole) Debug(format string, a ...interface{}) {
if !c.enable {
return
}
msg := fmt.Sprintf(format, a...)
println(aurora.BrightCyan(msg))
}
func (c *colorConsole) Success(format string, a ...interface{}) {
if !c.enable {
return
}
msg := fmt.Sprintf(format, a...)
println(aurora.BrightGreen(msg))
}
func (c *colorConsole) Warning(format string, a ...interface{}) {
if !c.enable {
return
}
msg := fmt.Sprintf(format, a...)
println(aurora.BrightYellow(msg))
}
func (c *colorConsole) Error(format string, a ...interface{}) {
if !c.enable {
return
}
msg := fmt.Sprintf(format, a...)
println(aurora.BrightRed(msg))
}
func (c *colorConsole) Fatalln(format string, a ...interface{}) {
if !c.enable {
return
}
c.Error(format, a...)
os.Exit(1)
}
func (c *colorConsole) MarkDone() {
if !c.enable {
return
}
c.Success("Done.")
}
func (c *colorConsole) Must(err error) {
if !c.enable {
return
}
if err != nil {
c.Fatalln("%+v", err)
}

View File

@@ -12,7 +12,7 @@ import (
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
func CloneIntoGitHome(url string, branch string) (dir string, err error) {
func CloneIntoGitHome(url, branch string) (dir string, err error) {
gitHome, err := pathx.GetGitHome()
if err != nil {
return "", err

View File

@@ -1,7 +1,12 @@
package util
var headTemplate = `// Code generated by goctl. DO NOT EDIT!
const (
// DoNotEditHead added to the beginning of a file to prompt the user not to edit
DoNotEditHead = "// Code generated by goctl. DO NOT EDIT!"
headTemplate = DoNotEditHead + `
// Source: {{.source}}`
)
// GetHead returns a code head string with source filename
func GetHead(source string) string {

View File

@@ -79,7 +79,7 @@ func TestGetGoctlHome(t *testing.T) {
t.Run("goctl_is_file", func(t *testing.T) {
tmpFile := filepath.Join(t.TempDir(), "a.tmp")
backupTempFile := tmpFile + ".old"
err := ioutil.WriteFile(tmpFile, nil, 0666)
err := ioutil.WriteFile(tmpFile, nil, 0o666)
if err != nil {
return
}
@@ -104,5 +104,4 @@ func TestGetGoctlHome(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, dir, home)
})
}