Compare commits

..

56 Commits

Author SHA1 Message Date
Kevin Wan
8478474f7f update readme (#673) 2021-05-08 21:55:14 +08:00
anqiansong
df5ae9507f replace antlr module (#672)
* replace antlr module

* refactor version of antlr
2021-05-08 21:35:27 +08:00
noel
faf4d7e3bb modify the order of PrometheusHandler (#670)
* modify the order of PrometheusHandler

* modify the order of PrometheusHandler
2021-05-08 17:11:16 +08:00
anqiansong
f64fe5eb5e fix antlr mod (#669) 2021-05-08 00:03:01 +08:00
heyanfu
97d889103a fix some typo (#667) 2021-05-04 21:33:08 +08:00
Kevin Wan
9a44310d00 update wechat qrcode (#665) 2021-05-02 15:06:16 +08:00
Kevin Wan
06eeef2cf3 disable prometheus if not configured (#663) 2021-04-30 15:09:49 +08:00
Kevin Wan
9adc7d4cb9 fix comment function names (#649) 2021-04-23 11:56:41 +08:00
Kevin Wan
006f78c3d5 add go-zero users (#643) 2021-04-21 10:24:15 +08:00
Kevin Wan
64a8e65f4a update readme (#640) 2021-04-20 23:57:57 +08:00
anqiansong
8fd1e76d29 update readme (#638) 2021-04-19 14:37:47 +08:00
heyanfu
0466af5e49 optimize code (#637) 2021-04-18 22:49:03 +08:00
heyanfu
7405d7f506 spelling mistakes (#634) 2021-04-17 20:15:19 +08:00
Bo-Yi Wu
afd9ff889e chore: update code format. (#628) 2021-04-15 19:49:17 +08:00
另维64
7e087de6e6 doc: fix spell mistake (#627) 2021-04-14 17:58:27 +08:00
Kevin Wan
5aded99df5 update go-zero users (#623) 2021-04-13 14:38:40 +08:00
Kevin Wan
08fb980ad2 add syncx.Guard func (#620) 2021-04-13 00:04:19 +08:00
Kevin Wan
b94d7aa532 update readme (#617) 2021-04-10 19:19:05 +08:00
Kevin Wan
ee630b8b57 add code coverage (#615)
* add code coverage

* simplify redis code
2021-04-09 22:40:43 +08:00
Kevin Wan
bd82b7d8de add FAQs in readme (#612) 2021-04-09 18:59:17 +08:00
Kevin Wan
3d729c77a6 update go-zero users (#611) 2021-04-09 14:16:31 +08:00
Kevin Wan
e944b59bb3 update go-zero users (#609)
* add go-zero users registry notes

* update go-zero users

* fix typo
2021-04-09 10:43:47 +08:00
Kevin Wan
54b5e3f4b2 add go-zero users registry notes (#608) 2021-04-08 22:44:41 +08:00
Kevin Wan
b913229028 add go-zero users (#607) 2021-04-08 22:30:45 +08:00
Kevin Wan
9963ffb1c1 simplify redis tls implementation (#606) 2021-04-08 18:19:36 +08:00
r00mz
8cb6490724 redis增加tls支持 (#595)
* redis连接增加支持tls选项

* 优化redis tls config 写法

* redis增加tls支持

* 增加redis tls测试用例,但redis tls local server不支持,测试用例全部NotNil

Co-authored-by: liuyi <liuyi@fangyb.com>
Co-authored-by: yi.liu <yi.liu@xshoppy.com>
2021-04-07 20:44:16 +08:00
Kevin Wan
05e37ee20f refactor - remove ShrinkDeadline, it's the same as context.WithTimeout (#599) 2021-04-05 22:59:24 +08:00
zjbztianya
d88da4cc88 Replace contextx.ShrinkDeadline with context.WithTimeout (#598) 2021-04-05 21:20:35 +08:00
Oraoto
425430f67c Simplify contextx.ShrinkDeadline (#596) 2021-04-03 21:25:32 +08:00
Zcc、
4e0d91f6c0 fix (#592)
Co-authored-by: zhoudeyu <zhoudeyu@xiaoheiban.cn>
2021-04-01 18:42:50 +08:00
Kevin Wan
8584351b6d update regression test comment (#590) 2021-03-30 21:23:07 +08:00
Kevin Wan
b19c5223a9 update regression test comment (#589) 2021-03-30 20:53:35 +08:00
bittoy
99a2d95433 remove rt mode log (#587) 2021-03-30 20:45:55 +08:00
Ted Chen
9db222bf5b fix a simple typo (#588) 2021-03-29 23:35:49 +08:00
Kevin Wan
ac648d08cb fix typo (#586) 2021-03-28 22:10:07 +08:00
Kevin Wan
6df7fa619c fix typo (#585) 2021-03-28 21:20:04 +08:00
Kevin Wan
bbb4ce586f fix golint issues (#584) 2021-03-28 20:42:11 +08:00
anqiansong
888551627c optimize code (#579)
* optimize code

* optimize returns & unit test
2021-03-27 17:33:17 +08:00
Kevin Wan
bd623aaac3 support postgresql (#583)
support postgresql
2021-03-27 17:14:32 +08:00
Kevin Wan
9e6c2ba2c0 avoid goroutine leak after timeout (#575) 2021-03-21 16:54:34 +08:00
Kevin Wan
c0db8d017d gofmt logs (#574) 2021-03-20 16:40:09 +08:00
TonyWang
52b4f8ca91 add timezone and timeformat (#572)
* add timezone and timeformat

* rm time zone and keep time format

Co-authored-by: Tony Wang <tonywang.data@gmail.com>
2021-03-20 16:36:19 +08:00
Kevin Wan
4884a7b3c6 zrpc timeout & unit tests (#573)
* zrpc timeout & unit tests
2021-03-19 18:41:26 +08:00
Kevin Wan
3c6951577d make hijack more stable (#565) 2021-03-15 20:11:09 +08:00
Kevin Wan
fcd15c9b17 refactor, and add comments to describe graceful shutdown (#564) 2021-03-14 08:51:10 +08:00
Kevin Wan
155e6061cb fix golint issues (#561) 2021-03-12 23:08:04 +08:00
anqiansong
dda7666097 Feature mongo gen (#546)
* add feature: mongo code generation

* upgrade version

* update doc

* format code

* update update.tpl of mysql
2021-03-12 17:49:28 +08:00
hanhotfox
c954568b61 Hdel support for multiple key deletion (#542)
* Hdel support for multiple key deletion

* Hdel field -> fields

Co-authored-by: duanyan <duanyan@xiaoheiban.cn>
2021-03-12 17:47:21 +08:00
Kevin Wan
c2acc43a52 add important notes in readme (#560) 2021-03-12 16:48:25 +08:00
Kevin Wan
1a1a6f5239 add http hijack methods (#555) 2021-03-09 21:30:45 +08:00
anqiansong
60c7edf8f8 fix spelling (#551) 2021-03-08 18:23:12 +08:00
Kevin Wan
7ad86a52f3 update doc link (#552) 2021-03-08 17:56:03 +08:00
kingxt
1e4e5a02b2 rename (#543) 2021-03-04 17:13:07 +08:00
Kevin Wan
39540e21d2 fix golint issues (#540) 2021-03-03 17:16:09 +08:00
hexiaoen
b321622c95 暴露redis EvalSha 以及ScriptLoad接口 (#538)
Co-authored-by: shanehe <shanehe@zego.im>
2021-03-03 17:09:27 +08:00
kingxt
a25cba5380 fix collection breaker (#537)
* fix collection breaker

* optimized

* optimized

* optimized
2021-03-03 10:44:29 +08:00
146 changed files with 1982 additions and 873 deletions

View File

@@ -122,8 +122,7 @@ func BenchmarkGoogleBreaker(b *testing.B) {
}
}
type mockedPromise struct {
}
type mockedPromise struct{}
func (m *mockedPromise) Accept() {
}

View File

@@ -40,7 +40,7 @@ func TestAesEcbBase64(t *testing.T) {
// more than 32 chars
badKey2 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
)
var key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
key := []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
b64Key := base64.StdEncoding.EncodeToString(key)
b64Val := base64.StdEncoding.EncodeToString([]byte(val))
_, err := EcbEncryptBase64(badKey1, val)

View File

@@ -139,7 +139,7 @@ func TestRollingWindowBucketTimeBoundary(t *testing.T) {
func TestRollingWindowDataRace(t *testing.T) {
const size = 3
r := NewRollingWindow(size, duration)
var stop = make(chan bool)
stop := make(chan bool)
go func() {
for {
select {

View File

@@ -25,7 +25,7 @@ func LoadConfig(file string, v interface{}, opts ...Option) error {
loader, ok := loaders[path.Ext(file)]
if !ok {
return fmt.Errorf("unrecoginized file type: %s", file)
return fmt.Errorf("unrecognized file type: %s", file)
}
var opt options

View File

@@ -1,19 +0,0 @@
package contextx
import (
"context"
"time"
)
// ShrinkDeadline returns a new Context with proper deadline base on the given ctx and timeout.
// And returns a cancel function as well.
func ShrinkDeadline(ctx context.Context, timeout time.Duration) (context.Context, func()) {
if deadline, ok := ctx.Deadline(); ok {
leftTime := time.Until(deadline)
if leftTime < timeout {
timeout = leftTime
}
}
return context.WithDeadline(ctx, time.Now().Add(timeout))
}

View File

@@ -1,31 +0,0 @@
package contextx
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestShrinkDeadlineLess(t *testing.T) {
deadline := time.Now().Add(time.Second)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
ctx, cancel = ShrinkDeadline(ctx, time.Minute)
defer cancel()
dl, ok := ctx.Deadline()
assert.True(t, ok)
assert.Equal(t, deadline, dl)
}
func TestShrinkDeadlineMore(t *testing.T) {
deadline := time.Now().Add(time.Minute)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
ctx, cancel = ShrinkDeadline(ctx, time.Second)
defer cancel()
dl, ok := ctx.Deadline()
assert.True(t, ok)
assert.True(t, dl.Before(deadline))
}

View File

@@ -8,7 +8,7 @@ import (
)
func TestChain(t *testing.T) {
var errDummy = errors.New("dummy")
errDummy := errors.New("dummy")
assert.Nil(t, Chain(func() error {
return nil
}, func() error {

View File

@@ -15,7 +15,7 @@ type (
// DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
func DoWithRetry(fn func() error, opts ...RetryOption) error {
var options = newRetryOptions()
options := newRetryOptions()
for _, opt := range opts {
opt(options)
}

View File

@@ -30,7 +30,7 @@ func TestRetry(t *testing.T) {
return errors.New("any")
}))
var total = 2 * defaultRetryTimes
total := 2 * defaultRetryTimes
times = 0
assert.Nil(t, DoWithRetry(func() error {
times++

View File

@@ -3,8 +3,6 @@ package fx
import (
"context"
"time"
"github.com/tal-tech/go-zero/core/contextx"
)
var (
@@ -23,10 +21,11 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
for _, opt := range opts {
parentCtx = opt()
}
ctx, cancel := contextx.ShrinkDeadline(parentCtx, timeout)
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
done := make(chan error)
// create channel with buffer size 1 to avoid goroutine leak
done := make(chan error, 1)
panicChan := make(chan interface{}, 1)
go func() {
defer func() {
@@ -35,7 +34,6 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
}
}()
done <- fn()
close(done)
}()
select {

View File

@@ -140,7 +140,7 @@ func (h *ConsistentHash) Remove(node interface{}) {
index := sort.Search(len(h.keys), func(i int) bool {
return h.keys[i] >= hash
})
if index < len(h.keys) {
if index < len(h.keys) && h.keys[index] == hash {
h.keys = append(h.keys[:index], h.keys[index+1:]...)
}
h.removeRingNode(hash, nodeRepr)

View File

@@ -168,7 +168,7 @@ func (as *adaptiveShedder) maxPass() int64 {
}
func (as *adaptiveShedder) minRt() float64 {
var result = defaultMinRt
result := defaultMinRt
as.rtCounter.Reduce(func(b *collection.Bucket) {
if b.Count <= 0 {

View File

@@ -201,7 +201,7 @@ func BenchmarkAdaptiveShedder_Allow(b *testing.B) {
logx.Disable()
bench := func(b *testing.B) {
var shedder = NewAdaptiveShedder()
shedder := NewAdaptiveShedder()
proba := mathx.NewProba()
for i := 0; i < 6000; i++ {
p, err := shedder.Allow()

View File

@@ -1,7 +1,6 @@
package load
type nopShedder struct {
}
type nopShedder struct{}
func newNopShedder() Shedder {
return nopShedder{}
@@ -11,8 +10,7 @@ func (s nopShedder) Allow() (Promise, error) {
return nopPromise{}, nil
}
type nopPromise struct {
}
type nopPromise struct{}
func (p nopPromise) Pass() {
}

View File

@@ -4,6 +4,7 @@ package logx
type LogConf struct {
ServiceName string `json:",optional"`
Mode string `json:",default=console,options=console|file|volume"`
TimeFormat string `json:",optional"`
Path string `json:",default=logs"`
Level string `json:",default=info,options=info|error|severe"`
Compress bool `json:",optional"`

View File

@@ -32,8 +32,6 @@ const (
)
const (
timeFormat = "2006-01-02T15:04:05.000Z07"
accessFilename = "access.log"
errorFilename = "error.log"
severeFilename = "severe.log"
@@ -64,6 +62,7 @@ var (
// ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
ErrLogServiceNameNotSet = errors.New("log service name must be set")
timeFormat = "2006-01-02T15:04:05.000Z07"
writeConsole bool
logLevel uint32
infoLog io.WriteCloser
@@ -117,6 +116,10 @@ func MustSetup(c LogConf) {
// we need to allow different service frameworks to initialize logx respectively.
// the same logic for SetUp
func SetUp(c LogConf) error {
if len(c.TimeFormat) > 0 {
timeFormat = c.TimeFormat
}
switch c.Mode {
case consoleMode:
setupWithConsole(c)

View File

@@ -22,8 +22,8 @@ const (
dateFormat = "2006-01-02"
hoursPerDay = 24
bufferSize = 100
defaultDirMode = 0755
defaultFileMode = 0600
defaultDirMode = 0o755
defaultFileMode = 0o600
)
// ErrLogFileClosed is an error that indicates the log file is already closed.

View File

@@ -752,7 +752,7 @@ func TestUnmarshalJsonNumberInt64(t *testing.T) {
for i := 0; i <= maxUintBitsToTest; i++ {
var intValue int64 = 1 << uint(i)
strValue := strconv.FormatInt(intValue, 10)
var number = json.Number(strValue)
number := json.Number(strValue)
m := map[string]interface{}{
"ID": number,
}
@@ -768,7 +768,7 @@ func TestUnmarshalJsonNumberUint64(t *testing.T) {
for i := 0; i <= maxUintBitsToTest; i++ {
var intValue uint64 = 1 << uint(i)
strValue := strconv.FormatUint(intValue, 10)
var number = json.Number(strValue)
number := json.Number(strValue)
m := map[string]interface{}{
"ID": number,
}
@@ -784,7 +784,7 @@ func TestUnmarshalJsonNumberUint64Ptr(t *testing.T) {
for i := 0; i <= maxUintBitsToTest; i++ {
var intValue uint64 = 1 << uint(i)
strValue := strconv.FormatUint(intValue, 10)
var number = json.Number(strValue)
number := json.Number(strValue)
m := map[string]interface{}{
"ID": number,
}

View File

@@ -170,6 +170,28 @@ func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
return false, nil
}
func isLeftInclude(b byte) (bool, error) {
switch b {
case '[':
return true, nil
case '(':
return false, nil
default:
return false, errNumberRange
}
}
func isRightInclude(b byte) (bool, error) {
switch b {
case ']':
return true, nil
case ')':
return false, nil
default:
return false, errNumberRange
}
}
func maybeNewValue(field reflect.StructField, value reflect.Value) {
if field.Type.Kind() == reflect.Ptr && value.IsNil() {
value.Set(reflect.New(value.Type().Elem()))
@@ -211,14 +233,9 @@ func parseNumberRange(str string) (*numberRange, error) {
return nil, errNumberRange
}
var leftInclude bool
switch str[0] {
case '[':
leftInclude = true
case '(':
leftInclude = false
default:
return nil, errNumberRange
leftInclude, err := isLeftInclude(str[0])
if err != nil {
return nil, err
}
str = str[1:]
@@ -226,14 +243,9 @@ func parseNumberRange(str string) (*numberRange, error) {
return nil, errNumberRange
}
var rightInclude bool
switch str[len(str)-1] {
case ']':
rightInclude = true
case ')':
rightInclude = false
default:
return nil, errNumberRange
rightInclude, err := isRightInclude(str[len(str)-1])
if err != nil {
return nil, err
}
str = str[:len(str)-1]

View File

@@ -16,8 +16,8 @@ type Foo struct {
}
func TestDeferInt(t *testing.T) {
var i = 1
var s = "hello"
i := 1
s := "hello"
number := struct {
f float64
}{

View File

@@ -36,7 +36,7 @@ func (u Unstable) AroundDuration(base time.Duration) time.Duration {
return val
}
// AroundInt returns a randome int64 with given base and deviation.
// AroundInt returns a random int64 with given base and deviation.
func (u Unstable) AroundInt(base int64) int64 {
u.lock.Lock()
val := int64((1 + u.deviation - 2*u.deviation*u.r.Float64()) * float64(base))

View File

@@ -12,7 +12,7 @@ type (
// CounterVec interface represents a counter vector.
CounterVec interface {
// Inc increments labels.
Inc(lables ...string)
Inc(labels ...string)
// Add adds labels with v.
Add(v float64, labels ...string)
close() bool
@@ -50,8 +50,8 @@ func (cv *promCounterVec) Inc(labels ...string) {
cv.counter.WithLabelValues(labels...).Inc()
}
func (cv *promCounterVec) Add(v float64, lables ...string) {
cv.counter.WithLabelValues(lables...).Add(v)
func (cv *promCounterVec) Add(v float64, labels ...string) {
cv.counter.WithLabelValues(labels...).Add(v)
}
func (cv *promCounterVec) close() bool {

View File

@@ -20,7 +20,7 @@ type (
close() bool
}
promGuageVec struct {
promGaugeVec struct {
gauge *prom.GaugeVec
}
)
@@ -39,7 +39,7 @@ func NewGaugeVec(cfg *GaugeVecOpts) GaugeVec {
Help: cfg.Help,
}, cfg.Labels)
prom.MustRegister(vec)
gv := &promGuageVec{
gv := &promGaugeVec{
gauge: vec,
}
proc.AddShutdownListener(func() {
@@ -49,18 +49,18 @@ func NewGaugeVec(cfg *GaugeVecOpts) GaugeVec {
return gv
}
func (gv *promGuageVec) Inc(labels ...string) {
func (gv *promGaugeVec) Inc(labels ...string) {
gv.gauge.WithLabelValues(labels...).Inc()
}
func (gv *promGuageVec) Add(v float64, lables ...string) {
gv.gauge.WithLabelValues(lables...).Add(v)
func (gv *promGaugeVec) Add(v float64, labels ...string) {
gv.gauge.WithLabelValues(labels...).Add(v)
}
func (gv *promGuageVec) Set(v float64, lables ...string) {
gv.gauge.WithLabelValues(lables...).Set(v)
func (gv *promGaugeVec) Set(v float64, labels ...string) {
gv.gauge.WithLabelValues(labels...).Set(v)
}
func (gv *promGuageVec) close() bool {
func (gv *promGaugeVec) close() bool {
return prom.Unregister(gv.gauge)
}

View File

@@ -29,7 +29,7 @@ func TestGaugeInc(t *testing.T) {
Labels: []string{"path"},
})
defer gaugeVec.close()
gv, _ := gaugeVec.(*promGuageVec)
gv, _ := gaugeVec.(*promGaugeVec)
gv.Inc("/users")
gv.Inc("/users")
r := testutil.ToFloat64(gv.gauge)
@@ -45,7 +45,7 @@ func TestGaugeAdd(t *testing.T) {
Labels: []string{"path"},
})
defer gaugeVec.close()
gv, _ := gaugeVec.(*promGuageVec)
gv, _ := gaugeVec.(*promGaugeVec)
gv.Add(-10, "/classroom")
gv.Add(30, "/classroom")
r := testutil.ToFloat64(gv.gauge)
@@ -61,7 +61,7 @@ func TestGaugeSet(t *testing.T) {
Labels: []string{"path"},
})
gaugeVec.close()
gv, _ := gaugeVec.(*promGuageVec)
gv, _ := gaugeVec.(*promGaugeVec)
gv.Set(666, "/users")
r := testutil.ToFloat64(gv.gauge)
assert.Equal(t, float64(666), r)

View File

@@ -19,7 +19,7 @@ type (
// A HistogramVec interface represents a histogram vector.
HistogramVec interface {
// Observe adds observation v to labels.
Observe(v int64, lables ...string)
Observe(v int64, labels ...string)
close() bool
}

View File

@@ -101,7 +101,7 @@ func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{}
}
// MapReduce maps all elements generated from given generate func,
// and reduces the output elemenets with given reducer.
// and reduces the output elements with given reducer.
func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) {
source := buildSource(generate)
return MapReduceWithSource(source, mapper, reducer, opts...)

View File

@@ -7,10 +7,19 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/syncx"
"github.com/tal-tech/go-zero/core/threading"
)
var once sync.Once
var (
once sync.Once
enabled syncx.AtomicBool
)
// Enabled returns if prometheus is enabled.
func Enabled() bool {
return enabled.True()
}
// StartAgent starts a prometheus agent.
func StartAgent(c Config) {
@@ -19,6 +28,7 @@ func StartAgent(c Config) {
return
}
enabled.Set(true)
threading.GoSafe(func() {
http.Handle(c.Path, promhttp.Handler())
addr := fmt.Sprintf("%s:%d", c.Host, c.Port)

View File

@@ -84,8 +84,7 @@ func (p *mockedProducer) Produce() (string, bool) {
return "", false
}
type mockedListener struct {
}
type mockedListener struct{}
func (l *mockedListener) OnPause() {
}

View File

@@ -14,6 +14,8 @@ const (
DevMode = "dev"
// TestMode means test mode.
TestMode = "test"
// RtMode means regression test mode.
RtMode = "rt"
// PreMode means pre-release mode.
PreMode = "pre"
// ProMode means production mode.
@@ -56,7 +58,7 @@ func (sc ServiceConf) SetUp() error {
func (sc ServiceConf) initMode() {
switch sc.Mode {
case DevMode, TestMode, PreMode:
case DevMode, TestMode, RtMode, PreMode:
load.Disable()
stat.SetReporter(nil)
}

View File

@@ -95,8 +95,7 @@ func WithStarter(start Starter) Service {
}
type (
stopper struct {
}
stopper struct{}
startOnlyService struct {
start func()

View File

@@ -70,7 +70,7 @@ func TestServiceGroup_WithStart(t *testing.T) {
wait.Add(len(multipliers))
group := NewServiceGroup()
for _, multiplier := range multipliers {
var mul = multiplier
mul := multiplier
group.Add(WithStart(func() {
lock.Lock()
want *= mul
@@ -97,7 +97,7 @@ func TestServiceGroup_WithStarter(t *testing.T) {
wait.Add(len(multipliers))
group := NewServiceGroup()
for _, multiplier := range multipliers {
var mul = multiplier
mul := multiplier
group.Add(WithStarter(mockedStarter{
fn: func() {
lock.Lock()

View File

@@ -59,7 +59,7 @@ func NewNode(rds *redis.Redis, barrier syncx.SharedCalls, st *Stat,
}
}
// DelCache deletes cached values with keys.
// Del deletes cached values with keys.
func (c cacheNode) Del(keys ...string) error {
if len(keys) == 0 {
return nil
@@ -73,7 +73,7 @@ func (c cacheNode) Del(keys ...string) error {
return nil
}
// GetCache gets the cache with key and fills into v.
// Get gets the cache with key and fills into v.
func (c cacheNode) Get(key string, v interface{}) error {
err := c.doGetCache(key, v)
if err == errPlaceholder {
@@ -88,12 +88,12 @@ func (c cacheNode) IsNotFound(err error) bool {
return err == c.errNotFound
}
// SetCache sets the cache with key and v, using c.expiry.
// Set sets the cache with key and v, using c.expiry.
func (c cacheNode) Set(key string, v interface{}) error {
return c.SetWithExpire(key, v, c.aroundDuration(c.expiry))
}
// SetCacheWithExpire sets the cache with key and v, using given expire.
// SetWithExpire sets the cache with key and v, using given expire.
func (c cacheNode) SetWithExpire(key string, v interface{}, expire time.Duration) error {
data, err := jsonx.Marshal(v)
if err != nil {
@@ -108,7 +108,7 @@ func (c cacheNode) String() string {
return c.rds.Addr
}
// TakeWithExpire takes the result from cache first, if not found,
// Take takes the result from cache first, if not found,
// query from DB and set cache using c.expiry, then return the result.
func (c cacheNode) Take(v interface{}, key string, query func(v interface{}) error) error {
return c.doTake(v, key, query, func(v interface{}) error {

View File

@@ -129,7 +129,7 @@ func TestCacheNode_TakeNotFound(t *testing.T) {
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
store.Del("any")
var errDummy = errors.New("dummy")
errDummy := errors.New("dummy")
err = cn.Take(&str, "any", func(v interface{}) error {
return errDummy
})

View File

@@ -12,8 +12,10 @@ import (
"github.com/tal-tech/go-zero/core/stringx"
)
var s1, _ = miniredis.Run()
var s2, _ = miniredis.Run()
var (
s1, _ = miniredis.Run()
s2, _ = miniredis.Run()
)
func TestRedis_Exists(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}

View File

@@ -43,11 +43,11 @@ type (
}
)
func newCollection(collection *mgo.Collection) Collection {
func newCollection(collection *mgo.Collection, brk breaker.Breaker) Collection {
return &decoratedCollection{
name: collection.FullName,
collection: collection,
brk: breaker.NewBreaker(),
brk: brk,
}
}

View File

@@ -71,7 +71,7 @@ func TestNewCollection(t *testing.T) {
Database: nil,
Name: "foo",
FullName: "bar",
})
}, breaker.GetBreaker("localhost"))
assert.Equal(t, "bar", col.(*decoratedCollection).name)
}
@@ -279,8 +279,7 @@ func (p *mockPromise) Reject(reason string) {
p.reason = reason
}
type dropBreaker struct {
}
type dropBreaker struct{}
func (d *dropBreaker) Name() string {
return "dummy"

View File

@@ -5,6 +5,7 @@ import (
"time"
"github.com/globalsign/mgo"
"github.com/tal-tech/go-zero/core/breaker"
)
type (
@@ -20,6 +21,7 @@ type (
session *concurrentSession
db *mgo.Database
collection string
brk breaker.Breaker
opts []Option
}
)
@@ -46,6 +48,7 @@ func NewModel(url, collection string, opts ...Option) (*Model, error) {
// If name is empty, the database name provided in the dialed URL is used instead
db: session.DB(""),
collection: collection,
brk: breaker.GetBreaker(url),
opts: opts,
}, nil
}
@@ -66,7 +69,7 @@ func (mm *Model) FindId(id interface{}) (Query, error) {
// GetCollection returns a Collection with given session.
func (mm *Model) GetCollection(session *mgo.Session) Collection {
return newCollection(mm.db.C(mm.collection).With(session))
return newCollection(mm.db.C(mm.collection).With(session), mm.brk)
}
// Insert inserts docs into mm.

View File

@@ -17,6 +17,7 @@ type (
Host string
Type string `json:",default=node,options=node|cluster"`
Pass string `json:",optional"`
Tls bool `json:",default=false,options=true|false"`
}
// A RedisKeyConf is a redis config with key.
@@ -28,7 +29,18 @@ type (
// NewRedis returns a Redis.
func (rc RedisConf) NewRedis() *Redis {
return NewRedis(rc.Host, rc.Type, rc.Pass)
var opts []Option
if rc.Type == ClusterType {
opts = append(opts, Cluster())
}
if len(rc.Pass) > 0 {
opts = append(opts, WithPass(rc.Pass))
}
if rc.Tls {
opts = append(opts, WithTLS())
}
return New(rc.Host, opts...)
}
// Validate validates the RedisConf.

View File

@@ -29,6 +29,9 @@ const (
var ErrNilNode = errors.New("nil redis node")
type (
// Option defines the method to customize a Redis.
Option func(r *Redis)
// A Pair is a key/pair set used in redis zset.
Pair struct {
Key string
@@ -40,6 +43,7 @@ type (
Addr string
Type string
Pass string
tls bool
brk breaker.Breaker
}
@@ -69,19 +73,32 @@ type (
FloatCmd = red.FloatCmd
)
// NewRedis returns a Redis.
func NewRedis(redisAddr, redisType string, redisPass ...string) *Redis {
var pass string
for _, v := range redisPass {
pass = v
}
return &Redis{
Addr: redisAddr,
Type: redisType,
Pass: pass,
// New returns a Redis with given options.
func New(addr string, opts ...Option) *Redis {
r := &Redis{
Addr: addr,
Type: NodeType,
brk: breaker.NewBreaker(),
}
for _, opt := range opts {
opt(r)
}
return r
}
// NewRedis returns a Redis.
func NewRedis(redisAddr, redisType string, redisPass ...string) *Redis {
var opts []Option
if redisType == ClusterType {
opts = append(opts, Cluster())
}
for _, v := range redisPass {
opts = append(opts, WithPass(v))
}
return New(redisAddr, opts...)
}
// BitCount is redis bitcount command implementation.
@@ -250,6 +267,21 @@ func (s *Redis) Eval(script string, keys []string, args ...interface{}) (val int
return
}
// EvalSha is the implementation of redis evalsha command.
func (s *Redis) EvalSha(sha string, keys []string, args ...interface{}) (val interface{}, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
val, err = conn.EvalSha(sha, keys, args...).Result()
return err
}, acceptable)
return
}
// Exists is the implementation of redis exists command.
func (s *Redis) Exists(key string) (val bool, err error) {
err = s.brk.DoWithAcceptable(func() error {
@@ -449,14 +481,14 @@ func (s *Redis) GetBit(key string, offset int64) (val int, err error) {
}
// Hdel is the implementation of redis hdel command.
func (s *Redis) Hdel(key, field string) (val bool, err error) {
func (s *Redis) Hdel(key string, fields ...string) (val bool, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
v, err := conn.HDel(key, field).Result()
v, err := conn.HDel(key, fields...).Result()
if err != nil {
return err
}
@@ -913,7 +945,6 @@ func (s *Redis) Pipelined(fn func(Pipeliner) error) (err error) {
_, err = conn.Pipelined(fn)
return err
}, acceptable)
return
@@ -1032,6 +1063,16 @@ func (s *Redis) Scard(key string) (val int64, err error) {
return
}
// ScriptLoad is the implementation of redis script load command.
func (s *Redis) ScriptLoad(script string) (string, error) {
conn, err := getRedis(s)
if err != nil {
return "", err
}
return conn.ScriptLoad(script).Result()
}
// Set is the implementation of redis set command.
func (s *Redis) Set(key string, value string) error {
return s.brk.DoWithAcceptable(func() error {
@@ -1101,26 +1142,6 @@ func (s *Redis) Sismember(key string, value interface{}) (val bool, err error) {
return
}
// Srem is the implementation of redis srem command.
func (s *Redis) Srem(key string, values ...interface{}) (val int, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
v, err := conn.SRem(key, values...).Result()
if err != nil {
return err
}
val = int(v)
return nil
}, acceptable)
return
}
// Smembers is the implementation of redis smembers command.
func (s *Redis) Smembers(key string) (val []string, err error) {
err = s.brk.DoWithAcceptable(func() error {
@@ -1166,6 +1187,31 @@ func (s *Redis) Srandmember(key string, count int) (val []string, err error) {
return
}
// Srem is the implementation of redis srem command.
func (s *Redis) Srem(key string, values ...interface{}) (val int, err error) {
err = s.brk.DoWithAcceptable(func() error {
conn, err := getRedis(s)
if err != nil {
return err
}
v, err := conn.SRem(key, values...).Result()
if err != nil {
return err
}
val = int(v)
return nil
}, acceptable)
return
}
// String returns the string representation of s.
func (s *Redis) String() string {
return s.Addr
}
// Sunion is the implementation of redis sunion command.
func (s *Redis) Sunion(keys ...string) (val []string, err error) {
err = s.brk.DoWithAcceptable(func() error {
@@ -1667,18 +1713,25 @@ func (s *Redis) Zunionstore(dest string, store ZStore, keys ...string) (val int6
return
}
// String returns the string representation of s.
func (s *Redis) String() string {
return s.Addr
// Cluster customizes the given Redis as a cluster.
func Cluster() Option {
return func(r *Redis) {
r.Type = ClusterType
}
}
func (s *Redis) scriptLoad(script string) (string, error) {
conn, err := getRedis(s)
if err != nil {
return "", err
// WithPass customizes the given Redis with given password.
func WithPass(pass string) Option {
return func(r *Redis) {
r.Pass = pass
}
}
return conn.ScriptLoad(script).Result()
// WithTLS customizes the given Redis with TLS enabled.
func WithTLS() Option {
return func(r *Redis) {
r.tls = true
}
}
func acceptable(err error) bool {
@@ -1688,9 +1741,9 @@ func acceptable(err error) bool {
func getRedis(r *Redis) (RedisNode, error) {
switch r.Type {
case ClusterType:
return getCluster(r.Addr, r.Pass)
return getCluster(r)
case NodeType:
return getClient(r.Addr, r.Pass)
return getClient(r)
default:
return nil, fmt.Errorf("redis type '%s' is not supported", r.Type)
}

View File

@@ -1,6 +1,7 @@
package redis
import (
"crypto/tls"
"errors"
"io"
"strconv"
@@ -14,7 +15,7 @@ import (
func TestRedis_Exists(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Exists("a")
_, err := New(client.Addr, badType()).Exists("a")
assert.NotNil(t, err)
ok, err := client.Exists("a")
assert.Nil(t, err)
@@ -26,9 +27,23 @@ func TestRedis_Exists(t *testing.T) {
})
}
func TestRedisTLS_Exists(t *testing.T) {
runOnRedisTLS(t, func(client *Redis) {
_, err := New(client.Addr, badType()).Exists("a")
assert.NotNil(t, err)
ok, err := client.Exists("a")
assert.NotNil(t, err)
assert.False(t, ok)
assert.NotNil(t, client.Set("a", "b"))
ok, err = client.Exists("a")
assert.NotNil(t, err)
assert.False(t, ok)
})
}
func TestRedis_Eval(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"})
_, err := New(client.Addr, badType()).Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"})
assert.NotNil(t, err)
_, err = client.Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"})
assert.Equal(t, Nil, err)
@@ -53,7 +68,7 @@ func TestRedis_Hgetall(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
_, err := NewRedis(client.Addr, "").Hgetall("a")
_, err := New(client.Addr, badType()).Hgetall("a")
assert.NotNil(t, err)
vals, err := client.Hgetall("a")
assert.Nil(t, err)
@@ -66,10 +81,10 @@ func TestRedis_Hgetall(t *testing.T) {
func TestRedis_Hvals(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.NotNil(t, NewRedis(client.Addr, "").Hset("a", "aa", "aaa"))
assert.NotNil(t, New(client.Addr, badType()).Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
_, err := NewRedis(client.Addr, "").Hvals("a")
_, err := New(client.Addr, badType()).Hvals("a")
assert.NotNil(t, err)
vals, err := client.Hvals("a")
assert.Nil(t, err)
@@ -81,7 +96,7 @@ func TestRedis_Hsetnx(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
_, err := NewRedis(client.Addr, "").Hsetnx("a", "bb", "ccc")
_, err := New(client.Addr, badType()).Hsetnx("a", "bb", "ccc")
assert.NotNil(t, err)
ok, err := client.Hsetnx("a", "bb", "ccc")
assert.Nil(t, err)
@@ -99,7 +114,7 @@ func TestRedis_HdelHlen(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
_, err := NewRedis(client.Addr, "").Hlen("a")
_, err := New(client.Addr, badType()).Hlen("a")
assert.NotNil(t, err)
num, err := client.Hlen("a")
assert.Nil(t, err)
@@ -115,7 +130,7 @@ func TestRedis_HdelHlen(t *testing.T) {
func TestRedis_HIncrBy(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Hincrby("key", "field", 2)
_, err := New(client.Addr, badType()).Hincrby("key", "field", 2)
assert.NotNil(t, err)
val, err := client.Hincrby("key", "field", 2)
assert.Nil(t, err)
@@ -130,7 +145,7 @@ func TestRedis_Hkeys(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
_, err := NewRedis(client.Addr, "").Hkeys("a")
_, err := New(client.Addr, badType()).Hkeys("a")
assert.NotNil(t, err)
vals, err := client.Hkeys("a")
assert.Nil(t, err)
@@ -142,7 +157,7 @@ func TestRedis_Hmget(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
_, err := NewRedis(client.Addr, "").Hmget("a", "aa", "bb")
_, err := New(client.Addr, badType()).Hmget("a", "aa", "bb")
assert.NotNil(t, err)
vals, err := client.Hmget("a", "aa", "bb")
assert.Nil(t, err)
@@ -155,7 +170,7 @@ func TestRedis_Hmget(t *testing.T) {
func TestRedis_Hmset(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.NotNil(t, NewRedis(client.Addr, "").Hmset("a", nil))
assert.NotNil(t, New(client.Addr, badType()).Hmset("a", nil))
assert.Nil(t, client.Hmset("a", map[string]string{
"aa": "aaa",
"bb": "bbb",
@@ -179,7 +194,7 @@ func TestRedis_Hscan(t *testing.T) {
var cursor uint64 = 0
sum := 0
for {
_, _, err := NewRedis(client.Addr, "").Hscan(key, cursor, "*", 100)
_, _, err := New(client.Addr, badType()).Hscan(key, cursor, "*", 100)
assert.NotNil(t, err)
reMap, next, err := client.Hscan(key, cursor, "*", 100)
assert.Nil(t, err)
@@ -191,7 +206,7 @@ func TestRedis_Hscan(t *testing.T) {
}
assert.Equal(t, sum, 3100)
_, err = NewRedis(client.Addr, "").Del(key)
_, err = New(client.Addr, badType()).Del(key)
assert.NotNil(t, err)
_, err = client.Del(key)
assert.Nil(t, err)
@@ -200,7 +215,7 @@ func TestRedis_Hscan(t *testing.T) {
func TestRedis_Incr(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Incr("a")
_, err := New(client.Addr, badType()).Incr("a")
assert.NotNil(t, err)
val, err := client.Incr("a")
assert.Nil(t, err)
@@ -213,7 +228,7 @@ func TestRedis_Incr(t *testing.T) {
func TestRedis_IncrBy(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Incrby("a", 2)
_, err := New(client.Addr, badType()).Incrby("a", 2)
assert.NotNil(t, err)
val, err := client.Incrby("a", 2)
assert.Nil(t, err)
@@ -230,7 +245,7 @@ func TestRedis_Keys(t *testing.T) {
assert.Nil(t, err)
err = client.Set("key2", "value2")
assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").Keys("*")
_, err = New(client.Addr, badType()).Keys("*")
assert.NotNil(t, err)
keys, err := client.Keys("*")
assert.Nil(t, err)
@@ -241,7 +256,7 @@ func TestRedis_Keys(t *testing.T) {
func TestRedis_HyperLogLog(t *testing.T) {
runOnRedis(t, func(client *Redis) {
client.Ping()
r := NewRedis(client.Addr, "")
r := New(client.Addr, badType())
_, err := r.Pfadd("key1")
assert.NotNil(t, err)
_, err = r.Pfcount("*")
@@ -253,17 +268,17 @@ func TestRedis_HyperLogLog(t *testing.T) {
func TestRedis_List(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Lpush("key", "value1", "value2")
_, err := New(client.Addr, badType()).Lpush("key", "value1", "value2")
assert.NotNil(t, err)
val, err := client.Lpush("key", "value1", "value2")
assert.Nil(t, err)
assert.Equal(t, 2, val)
_, err = NewRedis(client.Addr, "").Rpush("key", "value3", "value4")
_, err = New(client.Addr, badType()).Rpush("key", "value3", "value4")
assert.NotNil(t, err)
val, err = client.Rpush("key", "value3", "value4")
assert.Nil(t, err)
assert.Equal(t, 4, val)
_, err = NewRedis(client.Addr, "").Llen("key")
_, err = New(client.Addr, badType()).Llen("key")
assert.NotNil(t, err)
val, err = client.Llen("key")
assert.Nil(t, err)
@@ -271,7 +286,7 @@ func TestRedis_List(t *testing.T) {
vals, err := client.Lrange("key", 0, 10)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value2", "value1", "value3", "value4"}, vals)
_, err = NewRedis(client.Addr, "").Lpop("key")
_, err = New(client.Addr, badType()).Lpop("key")
assert.NotNil(t, err)
v, err := client.Lpop("key")
assert.Nil(t, err)
@@ -279,7 +294,7 @@ func TestRedis_List(t *testing.T) {
val, err = client.Lpush("key", "value1", "value2")
assert.Nil(t, err)
assert.Equal(t, 5, val)
_, err = NewRedis(client.Addr, "").Rpop("key")
_, err = New(client.Addr, badType()).Rpop("key")
assert.NotNil(t, err)
v, err = client.Rpop("key")
assert.Nil(t, err)
@@ -287,12 +302,12 @@ func TestRedis_List(t *testing.T) {
val, err = client.Rpush("key", "value4", "value3", "value3")
assert.Nil(t, err)
assert.Equal(t, 7, val)
_, err = NewRedis(client.Addr, "").Lrem("key", 2, "value1")
_, err = New(client.Addr, badType()).Lrem("key", 2, "value1")
assert.NotNil(t, err)
n, err := client.Lrem("key", 2, "value1")
assert.Nil(t, err)
assert.Equal(t, 2, n)
_, err = NewRedis(client.Addr, "").Lrange("key", 0, 10)
_, err = New(client.Addr, badType()).Lrange("key", 0, 10)
assert.NotNil(t, err)
vals, err = client.Lrange("key", 0, 10)
assert.Nil(t, err)
@@ -312,7 +327,7 @@ func TestRedis_Mget(t *testing.T) {
assert.Nil(t, err)
err = client.Set("key2", "value2")
assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").Mget("key1", "key0", "key2", "key3")
_, err = New(client.Addr, badType()).Mget("key1", "key0", "key2", "key3")
assert.NotNil(t, err)
vals, err := client.Mget("key1", "key0", "key2", "key3")
assert.Nil(t, err)
@@ -322,7 +337,7 @@ func TestRedis_Mget(t *testing.T) {
func TestRedis_SetBit(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := NewRedis(client.Addr, "").SetBit("key", 1, 1)
err := New(client.Addr, badType()).SetBit("key", 1, 1)
assert.NotNil(t, err)
err = client.SetBit("key", 1, 1)
assert.Nil(t, err)
@@ -333,7 +348,7 @@ func TestRedis_GetBit(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.SetBit("key", 2, 1)
assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").GetBit("key", 2)
_, err = New(client.Addr, badType()).GetBit("key", 2)
assert.NotNil(t, err)
val, err := client.GetBit("key", 2)
assert.Nil(t, err)
@@ -348,7 +363,7 @@ func TestRedis_BitCount(t *testing.T) {
assert.Nil(t, err)
}
_, err := NewRedis(client.Addr, "").BitCount("key", 0, -1)
_, err := New(client.Addr, badType()).BitCount("key", 0, -1)
assert.NotNil(t, err)
val, err := client.BitCount("key", 0, -1)
assert.Nil(t, err)
@@ -369,7 +384,6 @@ func TestRedis_BitCount(t *testing.T) {
val, err = client.BitCount("key", 2, 2)
assert.Nil(t, err)
assert.Equal(t, int64(0), val)
})
}
@@ -379,14 +393,14 @@ func TestRedis_BitOpAnd(t *testing.T) {
assert.Nil(t, err)
err = client.Set("key2", "1")
assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").BitOpAnd("destKey", "key1", "key2")
_, err = New(client.Addr, badType()).BitOpAnd("destKey", "key1", "key2")
assert.NotNil(t, err)
val, err := client.BitOpAnd("destKey", "key1", "key2")
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
valStr, err := client.Get("destKey")
assert.Nil(t, err)
//destKey binary 110000 ascii 0
// destKey binary 110000 ascii 0
assert.Equal(t, "0", valStr)
})
}
@@ -395,7 +409,7 @@ func TestRedis_BitOpNot(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.Set("key1", "\u0000")
assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").BitOpNot("destKey", "key1")
_, err = New(client.Addr, badType()).BitOpNot("destKey", "key1")
assert.NotNil(t, err)
val, err := client.BitOpNot("destKey", "key1")
assert.Nil(t, err)
@@ -412,7 +426,7 @@ func TestRedis_BitOpOr(t *testing.T) {
assert.Nil(t, err)
err = client.Set("key2", "0")
assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").BitOpOr("destKey", "key1", "key2")
_, err = New(client.Addr, badType()).BitOpOr("destKey", "key1", "key2")
assert.NotNil(t, err)
val, err := client.BitOpOr("destKey", "key1", "key2")
assert.Nil(t, err)
@@ -429,7 +443,7 @@ func TestRedis_BitOpXor(t *testing.T) {
assert.Nil(t, err)
err = client.Set("key2", "\x0f")
assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").BitOpXor("destKey", "key1", "key2")
_, err = New(client.Addr, badType()).BitOpXor("destKey", "key1", "key2")
assert.NotNil(t, err)
val, err := client.BitOpXor("destKey", "key1", "key2")
assert.Nil(t, err)
@@ -439,13 +453,14 @@ func TestRedis_BitOpXor(t *testing.T) {
assert.Equal(t, "\xf0", valStr)
})
}
func TestRedis_BitPos(t *testing.T) {
runOnRedis(t, func(client *Redis) {
//11111111 11110000 00000000
// 11111111 11110000 00000000
err := client.Set("key", "\xff\xf0\x00")
assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").BitPos("key", 0, 0, -1)
_, err = New(client.Addr, badType()).BitPos("key", 0, 0, -1)
assert.NotNil(t, err)
val, err := client.BitPos("key", 0, 0, 2)
assert.Nil(t, err)
@@ -466,13 +481,12 @@ func TestRedis_BitPos(t *testing.T) {
val, err = client.BitPos("key", 1, 2, 2)
assert.Nil(t, err)
assert.Equal(t, int64(-1), val)
})
}
func TestRedis_Persist(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Persist("key")
_, err := New(client.Addr, badType()).Persist("key")
assert.NotNil(t, err)
ok, err := client.Persist("key")
assert.Nil(t, err)
@@ -482,14 +496,14 @@ func TestRedis_Persist(t *testing.T) {
ok, err = client.Persist("key")
assert.Nil(t, err)
assert.False(t, ok)
err = NewRedis(client.Addr, "").Expire("key", 5)
err = New(client.Addr, badType()).Expire("key", 5)
assert.NotNil(t, err)
err = client.Expire("key", 5)
assert.Nil(t, err)
ok, err = client.Persist("key")
assert.Nil(t, err)
assert.True(t, ok)
err = NewRedis(client.Addr, "").Expireat("key", time.Now().Unix()+5)
err = New(client.Addr, badType()).Expireat("key", time.Now().Unix()+5)
assert.NotNil(t, err)
err = client.Expireat("key", time.Now().Unix()+5)
assert.Nil(t, err)
@@ -512,7 +526,7 @@ func TestRedis_Scan(t *testing.T) {
assert.Nil(t, err)
err = client.Set("key2", "value2")
assert.Nil(t, err)
_, _, err = NewRedis(client.Addr, "").Scan(0, "*", 100)
_, _, err = New(client.Addr, badType()).Scan(0, "*", 100)
assert.NotNil(t, err)
keys, _, err := client.Scan(0, "*", 100)
assert.Nil(t, err)
@@ -534,7 +548,7 @@ func TestRedis_Sscan(t *testing.T) {
var cursor uint64 = 0
sum := 0
for {
_, _, err := NewRedis(client.Addr, "").Sscan(key, cursor, "", 100)
_, _, err := New(client.Addr, badType()).Sscan(key, cursor, "", 100)
assert.NotNil(t, err)
keys, next, err := client.Sscan(key, cursor, "", 100)
assert.Nil(t, err)
@@ -546,7 +560,7 @@ func TestRedis_Sscan(t *testing.T) {
}
assert.Equal(t, sum, 1550)
_, err = NewRedis(client.Addr, "").Del(key)
_, err = New(client.Addr, badType()).Del(key)
assert.NotNil(t, err)
_, err = client.Del(key)
assert.Nil(t, err)
@@ -555,48 +569,48 @@ func TestRedis_Sscan(t *testing.T) {
func TestRedis_Set(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Sadd("key", 1, 2, 3, 4)
_, err := New(client.Addr, badType()).Sadd("key", 1, 2, 3, 4)
assert.NotNil(t, err)
num, err := client.Sadd("key", 1, 2, 3, 4)
assert.Nil(t, err)
assert.Equal(t, 4, num)
_, err = NewRedis(client.Addr, "").Scard("key")
_, err = New(client.Addr, badType()).Scard("key")
assert.NotNil(t, err)
val, err := client.Scard("key")
assert.Nil(t, err)
assert.Equal(t, int64(4), val)
_, err = NewRedis(client.Addr, "").Sismember("key", 2)
_, err = New(client.Addr, badType()).Sismember("key", 2)
assert.NotNil(t, err)
ok, err := client.Sismember("key", 2)
assert.Nil(t, err)
assert.True(t, ok)
_, err = NewRedis(client.Addr, "").Srem("key", 3, 4)
_, err = New(client.Addr, badType()).Srem("key", 3, 4)
assert.NotNil(t, err)
num, err = client.Srem("key", 3, 4)
assert.Nil(t, err)
assert.Equal(t, 2, num)
_, err = NewRedis(client.Addr, "").Smembers("key")
_, err = New(client.Addr, badType()).Smembers("key")
assert.NotNil(t, err)
vals, err := client.Smembers("key")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"1", "2"}, vals)
_, err = NewRedis(client.Addr, "").Srandmember("key", 1)
_, err = New(client.Addr, badType()).Srandmember("key", 1)
assert.NotNil(t, err)
members, err := client.Srandmember("key", 1)
assert.Nil(t, err)
assert.Len(t, members, 1)
assert.Contains(t, []string{"1", "2"}, members[0])
_, err = NewRedis(client.Addr, "").Spop("key")
_, err = New(client.Addr, badType()).Spop("key")
assert.NotNil(t, err)
member, err := client.Spop("key")
assert.Nil(t, err)
assert.Contains(t, []string{"1", "2"}, member)
_, err = NewRedis(client.Addr, "").Smembers("key")
_, err = New(client.Addr, badType()).Smembers("key")
assert.NotNil(t, err)
vals, err = client.Smembers("key")
assert.Nil(t, err)
assert.NotContains(t, vals, member)
_, err = NewRedis(client.Addr, "").Sadd("key1", 1, 2, 3, 4)
_, err = New(client.Addr, badType()).Sadd("key1", 1, 2, 3, 4)
assert.NotNil(t, err)
num, err = client.Sadd("key1", 1, 2, 3, 4)
assert.Nil(t, err)
@@ -604,22 +618,22 @@ func TestRedis_Set(t *testing.T) {
num, err = client.Sadd("key2", 2, 3, 4, 5)
assert.Nil(t, err)
assert.Equal(t, 4, num)
_, err = NewRedis(client.Addr, "").Sunion("key1", "key2")
_, err = New(client.Addr, badType()).Sunion("key1", "key2")
assert.NotNil(t, err)
vals, err = client.Sunion("key1", "key2")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"1", "2", "3", "4", "5"}, vals)
_, err = NewRedis(client.Addr, "").Sunionstore("key3", "key1", "key2")
_, err = New(client.Addr, badType()).Sunionstore("key3", "key1", "key2")
assert.NotNil(t, err)
num, err = client.Sunionstore("key3", "key1", "key2")
assert.Nil(t, err)
assert.Equal(t, 5, num)
_, err = NewRedis(client.Addr, "").Sdiff("key1", "key2")
_, err = New(client.Addr, badType()).Sdiff("key1", "key2")
assert.NotNil(t, err)
vals, err = client.Sdiff("key1", "key2")
assert.Nil(t, err)
assert.EqualValues(t, []string{"1"}, vals)
_, err = NewRedis(client.Addr, "").Sdiffstore("key4", "key1", "key2")
_, err = New(client.Addr, badType()).Sdiffstore("key4", "key1", "key2")
assert.NotNil(t, err)
num, err = client.Sdiffstore("key4", "key1", "key2")
assert.Nil(t, err)
@@ -629,11 +643,11 @@ func TestRedis_Set(t *testing.T) {
func TestRedis_SetGetDel(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := NewRedis(client.Addr, "").Set("hello", "world")
err := New(client.Addr, badType()).Set("hello", "world")
assert.NotNil(t, err)
err = client.Set("hello", "world")
assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").Get("hello")
_, err = New(client.Addr, badType()).Get("hello")
assert.NotNil(t, err)
val, err := client.Get("hello")
assert.Nil(t, err)
@@ -646,11 +660,11 @@ func TestRedis_SetGetDel(t *testing.T) {
func TestRedis_SetExNx(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := NewRedis(client.Addr, "").Setex("hello", "world", 5)
err := New(client.Addr, badType()).Setex("hello", "world", 5)
assert.NotNil(t, err)
err = client.Setex("hello", "world", 5)
assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").Setnx("hello", "newworld")
_, err = New(client.Addr, badType()).Setnx("hello", "newworld")
assert.NotNil(t, err)
ok, err := client.Setnx("hello", "newworld")
assert.Nil(t, err)
@@ -667,7 +681,7 @@ func TestRedis_SetExNx(t *testing.T) {
ttl, err := client.Ttl("hello")
assert.Nil(t, err)
assert.True(t, ttl > 0)
_, err = NewRedis(client.Addr, "").SetnxEx("newhello", "newworld", 5)
_, err = New(client.Addr, badType()).SetnxEx("newhello", "newworld", 5)
assert.NotNil(t, err)
ok, err = client.SetnxEx("newhello", "newworld", 5)
assert.Nil(t, err)
@@ -688,17 +702,17 @@ func TestRedis_SetGetDelHashField(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.Hset("key", "field", "value")
assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").Hget("key", "field")
_, err = New(client.Addr, badType()).Hget("key", "field")
assert.NotNil(t, err)
val, err := client.Hget("key", "field")
assert.Nil(t, err)
assert.Equal(t, "value", val)
_, err = NewRedis(client.Addr, "").Hexists("key", "field")
_, err = New(client.Addr, badType()).Hexists("key", "field")
assert.NotNil(t, err)
ok, err := client.Hexists("key", "field")
assert.Nil(t, err)
assert.True(t, ok)
_, err = NewRedis(client.Addr, "").Hdel("key", "field")
_, err = New(client.Addr, badType()).Hdel("key", "field")
assert.NotNil(t, err)
ret, err := client.Hdel("key", "field")
assert.Nil(t, err)
@@ -720,17 +734,17 @@ func TestRedis_SortedSet(t *testing.T) {
val, err := client.Zscore("key", "value1")
assert.Nil(t, err)
assert.Equal(t, int64(2), val)
_, err = NewRedis(client.Addr, "").Zincrby("key", 3, "value1")
_, err = New(client.Addr, badType()).Zincrby("key", 3, "value1")
assert.NotNil(t, err)
val, err = client.Zincrby("key", 3, "value1")
assert.Nil(t, err)
assert.Equal(t, int64(5), val)
_, err = NewRedis(client.Addr, "").Zscore("key", "value1")
_, err = New(client.Addr, badType()).Zscore("key", "value1")
assert.NotNil(t, err)
val, err = client.Zscore("key", "value1")
assert.Nil(t, err)
assert.Equal(t, int64(5), val)
_, err = NewRedis(client.Addr, "").Zadds("key")
_, err = New(client.Addr, badType()).Zadds("key")
assert.NotNil(t, err)
val, err = client.Zadds("key", Pair{
Key: "value2",
@@ -741,7 +755,7 @@ func TestRedis_SortedSet(t *testing.T) {
})
assert.Nil(t, err)
assert.Equal(t, int64(2), val)
_, err = NewRedis(client.Addr, "").ZRevRangeWithScores("key", 1, 3)
_, err = New(client.Addr, badType()).ZRevRangeWithScores("key", 1, 3)
assert.NotNil(t, err)
pairs, err := client.ZRevRangeWithScores("key", 1, 3)
assert.Nil(t, err)
@@ -761,11 +775,11 @@ func TestRedis_SortedSet(t *testing.T) {
rank, err = client.Zrevrank("key", "value1")
assert.Nil(t, err)
assert.Equal(t, int64(2), rank)
_, err = NewRedis(client.Addr, "").Zrank("key", "value4")
_, err = New(client.Addr, badType()).Zrank("key", "value4")
assert.NotNil(t, err)
_, err = client.Zrank("key", "value4")
assert.Equal(t, Nil, err)
_, err = NewRedis(client.Addr, "").Zrem("key", "value2", "value3")
_, err = New(client.Addr, badType()).Zrem("key", "value2", "value3")
assert.NotNil(t, err)
num, err := client.Zrem("key", "value2", "value3")
assert.Nil(t, err)
@@ -779,7 +793,7 @@ func TestRedis_SortedSet(t *testing.T) {
ok, err = client.Zadd("key", 8, "value4")
assert.Nil(t, err)
assert.True(t, ok)
_, err = NewRedis(client.Addr, "").Zremrangebyscore("key", 6, 7)
_, err = New(client.Addr, badType()).Zremrangebyscore("key", 6, 7)
assert.NotNil(t, err)
num, err = client.Zremrangebyscore("key", 6, 7)
assert.Nil(t, err)
@@ -787,37 +801,37 @@ func TestRedis_SortedSet(t *testing.T) {
ok, err = client.Zadd("key", 6, "value2")
assert.Nil(t, err)
assert.True(t, ok)
_, err = NewRedis(client.Addr, "").Zadd("key", 7, "value3")
_, err = New(client.Addr, badType()).Zadd("key", 7, "value3")
assert.NotNil(t, err)
ok, err = client.Zadd("key", 7, "value3")
assert.Nil(t, err)
assert.True(t, ok)
_, err = NewRedis(client.Addr, "").Zcount("key", 6, 7)
_, err = New(client.Addr, badType()).Zcount("key", 6, 7)
assert.NotNil(t, err)
num, err = client.Zcount("key", 6, 7)
assert.Nil(t, err)
assert.Equal(t, 2, num)
_, err = NewRedis(client.Addr, "").Zremrangebyrank("key", 1, 2)
_, err = New(client.Addr, badType()).Zremrangebyrank("key", 1, 2)
assert.NotNil(t, err)
num, err = client.Zremrangebyrank("key", 1, 2)
assert.Nil(t, err)
assert.Equal(t, 2, num)
_, err = NewRedis(client.Addr, "").Zcard("key")
_, err = New(client.Addr, badType()).Zcard("key")
assert.NotNil(t, err)
card, err := client.Zcard("key")
assert.Nil(t, err)
assert.Equal(t, 2, card)
_, err = NewRedis(client.Addr, "").Zrange("key", 0, -1)
_, err = New(client.Addr, badType()).Zrange("key", 0, -1)
assert.NotNil(t, err)
vals, err := client.Zrange("key", 0, -1)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value1", "value4"}, vals)
_, err = NewRedis(client.Addr, "").Zrevrange("key", 0, -1)
_, err = New(client.Addr, badType()).Zrevrange("key", 0, -1)
assert.NotNil(t, err)
vals, err = client.Zrevrange("key", 0, -1)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value4", "value1"}, vals)
_, err = NewRedis(client.Addr, "").ZrangeWithScores("key", 0, -1)
_, err = New(client.Addr, badType()).ZrangeWithScores("key", 0, -1)
assert.NotNil(t, err)
pairs, err = client.ZrangeWithScores("key", 0, -1)
assert.Nil(t, err)
@@ -831,7 +845,7 @@ func TestRedis_SortedSet(t *testing.T) {
Score: 8,
},
}, pairs)
_, err = NewRedis(client.Addr, "").ZrangebyscoreWithScores("key", 5, 8)
_, err = New(client.Addr, badType()).ZrangebyscoreWithScores("key", 5, 8)
assert.NotNil(t, err)
pairs, err = client.ZrangebyscoreWithScores("key", 5, 8)
assert.Nil(t, err)
@@ -845,7 +859,7 @@ func TestRedis_SortedSet(t *testing.T) {
Score: 8,
},
}, pairs)
_, err = NewRedis(client.Addr, "").ZrangebyscoreWithScoresAndLimit(
_, err = New(client.Addr, badType()).ZrangebyscoreWithScoresAndLimit(
"key", 5, 8, 1, 1)
assert.NotNil(t, err)
pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
@@ -859,7 +873,7 @@ func TestRedis_SortedSet(t *testing.T) {
pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 0)
assert.Nil(t, err)
assert.Equal(t, 0, len(pairs))
_, err = NewRedis(client.Addr, "").ZrevrangebyscoreWithScores("key", 5, 8)
_, err = New(client.Addr, badType()).ZrevrangebyscoreWithScores("key", 5, 8)
assert.NotNil(t, err)
pairs, err = client.ZrevrangebyscoreWithScores("key", 5, 8)
assert.Nil(t, err)
@@ -873,7 +887,7 @@ func TestRedis_SortedSet(t *testing.T) {
Score: 5,
},
}, pairs)
_, err = NewRedis(client.Addr, "").ZrevrangebyscoreWithScoresAndLimit(
_, err = New(client.Addr, badType()).ZrevrangebyscoreWithScoresAndLimit(
"key", 5, 8, 1, 1)
assert.NotNil(t, err)
pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
@@ -887,7 +901,7 @@ func TestRedis_SortedSet(t *testing.T) {
pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 0)
assert.Nil(t, err)
assert.Equal(t, 0, len(pairs))
_, err = NewRedis(client.Addr, "").Zrevrank("key", "value")
_, err = New(client.Addr, badType()).Zrevrank("key", "value")
assert.NotNil(t, err)
client.Zadd("second", 2, "aa")
client.Zadd("third", 3, "bbb")
@@ -897,6 +911,8 @@ func TestRedis_SortedSet(t *testing.T) {
}, "second", "third")
assert.Nil(t, err)
assert.Equal(t, int64(2), val)
_, err = New(client.Addr, badType()).Zunionstore("union", ZStore{})
assert.NotNil(t, err)
vals, err = client.Zrange("union", 0, 10000)
assert.Nil(t, err)
assert.EqualValues(t, []string{"aa", "bbb"}, vals)
@@ -908,7 +924,7 @@ func TestRedis_SortedSet(t *testing.T) {
func TestRedis_Pipelined(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.NotNil(t, NewRedis(client.Addr, "").Pipelined(func(pipeliner Pipeliner) error {
assert.NotNil(t, New(client.Addr, badType()).Pipelined(func(pipeliner Pipeliner) error {
return nil
}))
err := client.Pipelined(
@@ -920,7 +936,7 @@ func TestRedis_Pipelined(t *testing.T) {
},
)
assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").Ttl("pipelined_counter")
_, err = New(client.Addr, badType()).Ttl("pipelined_counter")
assert.NotNil(t, err)
ttl, err := client.Ttl("pipelined_counter")
assert.Nil(t, err)
@@ -940,20 +956,31 @@ func TestRedisString(t *testing.T) {
_, err := getRedis(NewRedis(client.Addr, ClusterType))
assert.Nil(t, err)
assert.Equal(t, client.Addr, client.String())
assert.NotNil(t, NewRedis(client.Addr, "").Ping())
assert.NotNil(t, New(client.Addr, badType()).Ping())
})
}
func TestRedisScriptLoad(t *testing.T) {
runOnRedis(t, func(client *Redis) {
client.Ping()
_, err := NewRedis(client.Addr, "").scriptLoad("foo")
_, err := New(client.Addr, badType()).ScriptLoad("foo")
assert.NotNil(t, err)
_, err = client.scriptLoad("foo")
_, err = client.ScriptLoad("foo")
assert.NotNil(t, err)
})
}
func TestRedisEvalSha(t *testing.T) {
runOnRedis(t, func(client *Redis) {
client.Ping()
scriptHash, err := client.ScriptLoad(`return redis.call("EXISTS", KEYS[1])`)
assert.Nil(t, err)
result, err := client.EvalSha(scriptHash, []string{"key1"})
assert.Nil(t, err)
assert.Equal(t, int64(0), result)
})
}
func TestRedisToPairs(t *testing.T) {
pairs := toPairs([]red.Z{
{
@@ -1007,7 +1034,7 @@ func TestRedisBlpopEx(t *testing.T) {
func TestRedisGeo(t *testing.T) {
runOnRedis(t, func(client *Redis) {
client.Ping()
var geoLocation = []*GeoLocation{{Longitude: 13.361389, Latitude: 38.115556, Name: "Palermo"}, {Longitude: 15.087269, Latitude: 37.502669, Name: "Catania"}}
geoLocation := []*GeoLocation{{Longitude: 13.361389, Latitude: 38.115556, Name: "Palermo"}, {Longitude: 15.087269, Latitude: 37.502669, Name: "Catania"}}
v, err := client.GeoAdd("sicily", geoLocation...)
assert.Nil(t, err)
assert.Equal(t, int64(2), v)
@@ -1025,7 +1052,7 @@ func TestRedisGeo(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, int64(v4[0].Dist), int64(190))
assert.Equal(t, int64(v4[1].Dist), int64(56))
var geoLocation2 = []*GeoLocation{{Longitude: 13.583333, Latitude: 37.316667, Name: "Agrigento"}}
geoLocation2 := []*GeoLocation{{Longitude: 13.583333, Latitude: 37.316667, Name: "Agrigento"}}
v5, err := client.GeoAdd("sicily", geoLocation2...)
assert.Nil(t, err)
assert.Equal(t, int64(1), v5)
@@ -1036,6 +1063,13 @@ func TestRedisGeo(t *testing.T) {
})
}
func TestRedis_WithPass(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := NewRedis(client.Addr, NodeType, "any").Ping()
assert.NotNil(t, err)
})
}
func runOnRedis(t *testing.T, fn func(client *Redis)) {
s, err := miniredis.Run()
assert.Nil(t, err)
@@ -1051,10 +1085,35 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) {
client.Close()
}
}()
fn(NewRedis(s.Addr(), NodeType))
}
func runOnRedisTLS(t *testing.T, fn func(client *Redis)) {
s, err := miniredis.RunTLS(&tls.Config{
Certificates: make([]tls.Certificate, 1),
InsecureSkipVerify: true,
})
assert.Nil(t, err)
defer func() {
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
return nil, errors.New("should already exist")
})
if err != nil {
t.Error(err)
}
if client != nil {
client.Close()
}
}()
fn(New(s.Addr(), WithTLS()))
}
func badType() Option {
return func(r *Redis) {
r.Type = "bad"
}
}
type mockedNode struct {
RedisNode
}

View File

@@ -1,6 +1,7 @@
package redis
import (
"crypto/tls"
"io"
red "github.com/go-redis/redis"
@@ -15,14 +16,21 @@ const (
var clientManager = syncx.NewResourceManager()
func getClient(server, pass string) (*red.Client, error) {
val, err := clientManager.GetResource(server, func() (io.Closer, error) {
func getClient(r *Redis) (*red.Client, error) {
val, err := clientManager.GetResource(r.Addr, func() (io.Closer, error) {
var tlsConfig *tls.Config
if r.tls {
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
store := red.NewClient(&red.Options{
Addr: server,
Password: pass,
Addr: r.Addr,
Password: r.Pass,
DB: defaultDatabase,
MaxRetries: maxRetries,
MinIdleConns: idleConns,
TLSConfig: tlsConfig,
})
store.WrapProcess(process)
return store, nil

View File

@@ -1,6 +1,7 @@
package redis
import (
"crypto/tls"
"io"
red "github.com/go-redis/redis"
@@ -9,13 +10,20 @@ import (
var clusterManager = syncx.NewResourceManager()
func getCluster(server, pass string) (*red.ClusterClient, error) {
val, err := clusterManager.GetResource(server, func() (io.Closer, error) {
func getCluster(r *Redis) (*red.ClusterClient, error) {
val, err := clusterManager.GetResource(r.Addr, func() (io.Closer, error) {
var tlsConfig *tls.Config
if r.tls {
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
store := red.NewClusterClient(&red.ClusterOptions{
Addrs: []string{server},
Password: pass,
Addrs: []string{r.Addr},
Password: r.Pass,
MaxRetries: maxRetries,
MinIdleConns: idleConns,
TLSConfig: tlsConfig,
})
store.WrapProcess(process)

View File

@@ -53,7 +53,8 @@ func NewRedisLock(store *Redis, key string) *RedisLock {
func (rl *RedisLock) Acquire() (bool, error) {
seconds := atomic.LoadUint32(&rl.seconds)
resp, err := rl.store.Eval(lockCommand, []string{rl.key}, []string{
rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance)})
rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance),
})
if err == red.Nil {
return false, nil
} else if err != nil {

View File

@@ -24,6 +24,7 @@ type (
ResultHandler func(sql.Result, error)
// A BulkInserter is used to batch insert records.
// Postgresql is not supported yet, because of the sql is formated with symbol `$`.
BulkInserter struct {
executor *executors.PeriodicalExecutor
inserter *dbInserter

View File

@@ -206,7 +206,7 @@ func TestUnmarshalRowString(t *testing.T) {
}
func TestUnmarshalRowStruct(t *testing.T) {
var value = new(struct {
value := new(struct {
Name string
Age int
})
@@ -224,7 +224,7 @@ func TestUnmarshalRowStruct(t *testing.T) {
}
func TestUnmarshalRowStructWithTags(t *testing.T) {
var value = new(struct {
value := new(struct {
Age int `db:"age"`
Name string `db:"name"`
})
@@ -242,7 +242,7 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
}
func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
var value = new(struct {
value := new(struct {
Age *int `db:"age"`
Name string `db:"name"`
})
@@ -259,7 +259,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
func TestUnmarshalRowsBool(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []bool{true, false}
expect := []bool{true, false}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -273,7 +273,7 @@ func TestUnmarshalRowsBool(t *testing.T) {
func TestUnmarshalRowsInt(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int{2, 3}
expect := []int{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -287,7 +287,7 @@ func TestUnmarshalRowsInt(t *testing.T) {
func TestUnmarshalRowsInt8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int8{2, 3}
expect := []int8{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -301,7 +301,7 @@ func TestUnmarshalRowsInt8(t *testing.T) {
func TestUnmarshalRowsInt16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int16{2, 3}
expect := []int16{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -315,7 +315,7 @@ func TestUnmarshalRowsInt16(t *testing.T) {
func TestUnmarshalRowsInt32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int32{2, 3}
expect := []int32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -329,7 +329,7 @@ func TestUnmarshalRowsInt32(t *testing.T) {
func TestUnmarshalRowsInt64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int64{2, 3}
expect := []int64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -343,7 +343,7 @@ func TestUnmarshalRowsInt64(t *testing.T) {
func TestUnmarshalRowsUint(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint{2, 3}
expect := []uint{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -357,7 +357,7 @@ func TestUnmarshalRowsUint(t *testing.T) {
func TestUnmarshalRowsUint8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint8{2, 3}
expect := []uint8{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -371,7 +371,7 @@ func TestUnmarshalRowsUint8(t *testing.T) {
func TestUnmarshalRowsUint16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint16{2, 3}
expect := []uint16{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -385,7 +385,7 @@ func TestUnmarshalRowsUint16(t *testing.T) {
func TestUnmarshalRowsUint32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint32{2, 3}
expect := []uint32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -399,7 +399,7 @@ func TestUnmarshalRowsUint32(t *testing.T) {
func TestUnmarshalRowsUint64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint64{2, 3}
expect := []uint64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -413,7 +413,7 @@ func TestUnmarshalRowsUint64(t *testing.T) {
func TestUnmarshalRowsFloat32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []float32{2, 3}
expect := []float32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -427,7 +427,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) {
func TestUnmarshalRowsFloat64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []float64{2, 3}
expect := []float64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -441,7 +441,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) {
func TestUnmarshalRowsString(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []string{"hello", "world"}
expect := []string{"hello", "world"}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -457,7 +457,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) {
yes := true
no := false
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*bool{&yes, &no}
expect := []*bool{&yes, &no}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -473,7 +473,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) {
two := 2
three := 3
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int{&two, &three}
expect := []*int{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -489,7 +489,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) {
two := int8(2)
three := int8(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int8{&two, &three}
expect := []*int8{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -505,7 +505,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) {
two := int16(2)
three := int16(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int16{&two, &three}
expect := []*int16{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -521,7 +521,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) {
two := int32(2)
three := int32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int32{&two, &three}
expect := []*int32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -537,7 +537,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) {
two := int64(2)
three := int64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int64{&two, &three}
expect := []*int64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -553,7 +553,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) {
two := uint(2)
three := uint(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint{&two, &three}
expect := []*uint{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -569,7 +569,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) {
two := uint8(2)
three := uint8(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint8{&two, &three}
expect := []*uint8{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -585,7 +585,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) {
two := uint16(2)
three := uint16(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint16{&two, &three}
expect := []*uint16{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -601,7 +601,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) {
two := uint32(2)
three := uint32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint32{&two, &three}
expect := []*uint32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -617,7 +617,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) {
two := uint64(2)
three := uint64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint64{&two, &three}
expect := []*uint64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -633,7 +633,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
two := float32(2)
three := float32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*float32{&two, &three}
expect := []*float32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -649,7 +649,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
two := float64(2)
three := float64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*float64{&two, &three}
expect := []*float64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -665,7 +665,7 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
hello := "hello"
world := "world"
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*string{&hello, &world}
expect := []*string{&hello, &world}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -678,7 +678,7 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
}
func TestUnmarshalRowsStruct(t *testing.T) {
var expect = []struct {
expect := []struct {
Name string
Age int64
}{
@@ -711,7 +711,7 @@ func TestUnmarshalRowsStruct(t *testing.T) {
}
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
var expect = []struct {
expect := []struct {
Name string
NullString sql.NullString
}{
@@ -752,7 +752,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
}
func TestUnmarshalRowsStructWithTags(t *testing.T) {
var expect = []struct {
expect := []struct {
Name string
Age int64
}{
@@ -789,7 +789,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
Value int64 `db:"value"`
}
var expect = []struct {
expect := []struct {
Name string
Age int64
Value int64
@@ -831,7 +831,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
Value int64 `db:"value"`
}
var expect = []struct {
expect := []struct {
Name string
Age int64
Value int64
@@ -869,7 +869,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
}
func TestUnmarshalRowsStructPtr(t *testing.T) {
var expect = []*struct {
expect := []*struct {
Name string
Age int64
}{
@@ -902,7 +902,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) {
}
func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
var expect = []*struct {
expect := []*struct {
Name string
Age int64
}{
@@ -935,7 +935,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
}
func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
var expect = []*struct {
expect := []*struct {
Name string
Age int64
}{

View File

@@ -12,14 +12,10 @@ import (
const slowThreshold = time.Millisecond * 500
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now()
result, err := conn.Exec(q, args...)
duration := timex.Since(startTime)
stmt := formatForPrint(q, args)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
} else {
@@ -33,10 +29,10 @@ func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
}
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
stmt := fmt.Sprint(args...)
startTime := timex.Now()
result, err := conn.Exec(args...)
duration := timex.Since(startTime)
stmt := fmt.Sprint(args...)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
} else {
@@ -50,14 +46,10 @@ func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
}
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now()
rows, err := conn.Query(q, args...)
duration := timex.Since(startTime)
stmt := fmt.Sprint(args...)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
} else {

View File

@@ -16,7 +16,6 @@ func TestStmt_exec(t *testing.T) {
name string
args []interface{}
delay bool
formatError bool
hasError bool
err error
lastInsertId int64
@@ -28,12 +27,6 @@ func TestStmt_exec(t *testing.T) {
lastInsertId: 1,
rowsAffected: 2,
},
{
name: "wrong format",
args: []interface{}{1, 2},
formatError: true,
hasError: true,
},
{
name: "exec error",
args: []interface{}{1},
@@ -70,18 +63,13 @@ func TestStmt_exec(t *testing.T) {
},
}
for i, fn := range fns {
i := i
for _, fn := range fns {
fn := fn
t.Run(test.name, func(t *testing.T) {
t.Parallel()
res, err := fn(test.args...)
if i == 0 && test.formatError {
assert.NotNil(t, err)
return
}
if !test.formatError && test.hasError {
if test.hasError {
assert.NotNil(t, err)
return
}
@@ -100,23 +88,16 @@ func TestStmt_exec(t *testing.T) {
func TestStmt_query(t *testing.T) {
tests := []struct {
name string
args []interface{}
delay bool
formatError bool
hasError bool
err error
name string
args []interface{}
delay bool
hasError bool
err error
}{
{
name: "normal",
args: []interface{}{1},
},
{
name: "wrong format",
args: []interface{}{1, 2},
formatError: true,
hasError: true,
},
{
name: "query error",
args: []interface{}{1},
@@ -151,18 +132,13 @@ func TestStmt_query(t *testing.T) {
},
}
for i, fn := range fns {
i := i
for _, fn := range fns {
fn := fn
t.Run(test.name, func(t *testing.T) {
t.Parallel()
err := fn(test.args...)
if i == 0 && test.formatError {
assert.NotNil(t, err)
return
}
if !test.formatError && test.hasError {
if test.hasError {
assert.NotNil(t, err)
return
}

View File

@@ -45,6 +45,24 @@ func escape(input string) string {
return b.String()
}
func formatForPrint(query string, args ...interface{}) string {
if len(args) == 0 {
return query
}
var vals []string
for _, arg := range args {
vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
}
var b strings.Builder
b.WriteByte('[')
b.WriteString(strings.Join(vals, ", "))
b.WriteByte(']')
return strings.Join([]string{query, b.String()}, " ")
}
func format(query string, args ...interface{}) (string, error) {
numArgs := len(args)
if numArgs == 0 {

View File

@@ -28,3 +28,31 @@ func TestDesensitize_WithoutAccount(t *testing.T) {
datasource = desensitize(datasource)
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
}
func TestFormatForPrint(t *testing.T) {
tests := []struct {
name string
query string
args []interface{}
expect string
}{
{
name: "no args",
query: "select user, name from table where id=?",
expect: `select user, name from table where id=?`,
},
{
name: "one arg",
query: "select user, name from table where id=?",
args: []interface{}{"kevin"},
expect: `select user, name from table where id=? ["kevin"]`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual := formatForPrint(test.query, test.args...)
assert.Equal(t, test.expect, actual)
})
}
}

View File

@@ -16,7 +16,7 @@ type (
// NewReplacer returns a Replacer.
func NewReplacer(mapping map[string]string) Replacer {
var rep = &replacer{
rep := &replacer{
mapping: mapping,
}
for k := range mapping {
@@ -28,9 +28,9 @@ func NewReplacer(mapping map[string]string) Replacer {
func (r *replacer) Replace(text string) string {
var builder strings.Builder
var chars = []rune(text)
var size = len(chars)
var start = -1
chars := []rune(text)
size := len(chars)
start := -1
for i := 0; i < size; i++ {
child, ok := r.children[chars[i]]
@@ -42,12 +42,12 @@ func (r *replacer) Replace(text string) string {
if start < 0 {
start = i
}
var end = -1
end := -1
if child.end {
end = i + 1
}
var j = i + 1
j := i + 1
for ; j < size; j++ {
grandchild, ok := child.children[chars[j]]
if !ok {

View File

@@ -7,7 +7,7 @@ import (
)
func TestReplacer_Replace(t *testing.T) {
var mapping = map[string]string{
mapping := map[string]string{
"一二三四": "1234",
"二三": "23",
"二": "2",
@@ -16,28 +16,28 @@ func TestReplacer_Replace(t *testing.T) {
}
func TestReplacer_ReplaceSingleChar(t *testing.T) {
var mapping = map[string]string{
mapping := map[string]string{
"二": "2",
}
assert.Equal(t, "零一2三四五", NewReplacer(mapping).Replace("零一二三四五"))
}
func TestReplacer_ReplaceExceedRange(t *testing.T) {
var mapping = map[string]string{
mapping := map[string]string{
"二三四五六": "23456",
}
assert.Equal(t, "零一二三四五", NewReplacer(mapping).Replace("零一二三四五"))
}
func TestReplacer_ReplacePartialMatch(t *testing.T) {
var mapping = map[string]string{
mapping := map[string]string{
"二三四七": "2347",
}
assert.Equal(t, "零一二三四五", NewReplacer(mapping).Replace("零一二三四五"))
}
func TestReplacer_ReplaceMultiMatches(t *testing.T) {
var mapping = map[string]string{
mapping := map[string]string{
"二三": "23",
}
assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五"))

View File

@@ -9,7 +9,12 @@ type Barrier struct {
// Guard guards the given fn on the resource.
func (b *Barrier) Guard(fn func()) {
b.lock.Lock()
defer b.lock.Unlock()
Guard(&b.lock, fn)
}
// Guard guards the given fn with lock.
func Guard(lock sync.Locker, fn func()) {
lock.Lock()
defer lock.Unlock()
fn()
}

View File

@@ -38,3 +38,19 @@ func TestBarrierPtr_Guard(t *testing.T) {
wg.Wait()
assert.Equal(t, total, count)
}
func TestGuard(t *testing.T) {
const total = 10000
var count int
var lock sync.Mutex
wg := new(sync.WaitGroup)
wg.Add(total)
for i := 0; i < total; i++ {
go Guard(&lock, func() {
count++
wg.Done()
})
}
wg.Wait()
assert.Equal(t, total, count)
}

2
go.mod
View File

@@ -6,7 +6,6 @@ require (
github.com/ClickHouse/clickhouse-go v1.4.3
github.com/DATA-DOG/go-sqlmock v1.4.1
github.com/alicebob/miniredis/v2 v2.14.1
github.com/antlr/antlr4 v0.0.0-20210105212045-464bcbc32de2
github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/emicklei/proto v1.9.0
@@ -44,6 +43,7 @@ require (
github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 // indirect
github.com/urfave/cli v1.22.5
github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
github.com/zeromicro/antlr v0.0.1 // indirect
go.etcd.io/etcd v0.0.0-20200402134248-51bdeb39e698
go.uber.org/automaxprocs v1.3.0
golang.org/x/lint v0.0.0-20200302205851-738671d3881b // indirect

4
go.sum
View File

@@ -276,6 +276,10 @@ github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2/go.mod h1:hzfGeI
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb h1:ZkM6LRnq40pR1Ox0hTHlnpkcOTuFIDQpZ1IN8rKKhX0=
github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb/go.mod h1:gqRgreBUhTSL0GeU64rtZ3Uq3wtjOa/TB2YfrtkCbVQ=
github.com/zeromicro/antlr v0.0.0-20210508120604-8d7a7786f5c4 h1:wt86gT5lsN+xWmij6lFCrDIqoxeOnqM2MxiO7cNm+Lo=
github.com/zeromicro/antlr v0.0.0-20210508120604-8d7a7786f5c4/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
github.com/zeromicro/antlr v0.0.1 h1:CQpIn/dc0pUjgGQ81y98s/NGOm2Hfru2NNio2I9mQgk=
github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
go.etcd.io/bbolt v1.3.4 h1:hi1bXHMVrlQh6WwxAy+qZCV/SYIlqo+Ushwdpa4tAKg=
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
go.etcd.io/etcd v0.0.0-20200402134248-51bdeb39e698 h1:jWtjCJX1qxhHISBMLRztWwR+EXkI7MJAF2HjHAE/x/I=

View File

@@ -159,7 +159,17 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
* API 文档
[https://www.yuque.com/tal-tech/go-zero](https://www.yuque.com/tal-tech/go-zero)
[https://go-zero.dev/zh-hans/](https://zeromicro.github.io/go-zero)
* 常见问题
* 因为 `etcd` 和 `grpc` 兼容性问题,请使用 `grpc@v1.29.1`
`google.golang.org/grpc v1.29.1`
* 因为 `protobuf` 兼容性问题,请使用 `protocol-gen@v1.3.2`
`go get -u github.com/golang/protobuf/protoc-gen-go@v1.3.2`
* awesome 系列(更多文章见『微服务实践』公众号)
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
@@ -175,13 +185,39 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
| [goctl-android](https://github.com/zeromicro/goctl-android) | 生成 `java (android)` 端 `http client` 请求代码 |
| [goctl-go-compact](https://github.com/zeromicro/goctl-go-compact) | 合并 `api` 里同一个 `group` 里的 `handler` 到一个 `go` 文件 |
## 8. 微信公众号
## 8. go-zero 用户
`go-zero` 相关文章都会在 `微服务实践` 公众号整理呈现,欢迎扫码关注,也可以通过公众号私信我 👏
go-zero 已被许多公司用于生产部署,接入场景如在线教育、电商业务、游戏、区块链等,目前为止,已使用 go-zero 的公司包括但不限于:
>1. 好未来
>2. 上海晓信信息科技有限公司(晓黑板)
>3. 上海玉数科技有限公司
>4. 常州千帆网络科技有限公司
>5. 上班族科技
>6. 英雄体育VSPN
>7. githubmemory
>8. 释空(上海)品牌策划有限公司(senkoo)
>9. 鞍山三合众鑫科技有限公司
>10. 广州星梦工场网络科技有限公司
>11. 杭州复杂美科技有限公司
>12. 赛凌科技
>13. 捞月狗
>14. 浙江三合通信科技有限公司
>15. 爱克萨
>16. 郑州众合互联信息技术有限公司
>17. 三七游戏
>18. 成都创道夫科技有限公司
>19. 联想Lenovo
如果贵公司也已使用 go-zero欢迎在 [登记地址](https://github.com/tal-tech/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
## 9. 微信公众号
`go-zero` 相关文章和视频都会在 `微服务实践` 公众号整理呈现,欢迎扫码关注 👏
<img src="https://gitee.com/kevwan/static/raw/master/images/wechat-micro.jpg" alt="wechat" width="300" />
## 9. 微信交流群
## 10. 微信交流群
如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。
@@ -189,12 +225,6 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
如果您发现 ***bug*** 请及时提 ***issue***,我们会尽快确认并修改。
为了防止广告用户、识别技术同行,请 ***star*** 后加我时注明 **github** 当前 ***star*** 数,我再拉进 **go-zero** 群,感谢!
加群之前有劳点一下 ***star***,一个小小的 ***star*** 是作者们回答海量问题的动力🤝
加我之前有劳点一下 ***star***,一个小小的 ***star*** 是作者们回答海量问题的动力🤝
<img src="https://gitee.com/kevwan/static/raw/master/images/wechat.jpg" alt="wechat" width="300" />
项目地址:[https://github.com/tal-tech/go-zero](https://github.com/tal-tech/go-zero)
码云地址:[https://gitee.com/kevwan/go-zero](https://gitee.com/kevwan/go-zero) (国内用户可访问gitee每日自动从github同步代码)
<img src="https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/wechat.jpg" alt="wechat" width="300" />

View File

@@ -12,7 +12,7 @@ English | [简体中文](readme-cn.md)
## 0. what is go-zero
go-zero is a web and rpc framework that with lots of engineering practices builtin. Its born to ensure the stability of the busy services with resilience design, and has been serving sites with tens of millions users for years.
go-zero is a web and rpc framework with lots of builtin engineering practices. Its born to ensure the stability of the busy services with resilience design, and has been serving sites with tens of millions users for years.
go-zero contains simple API description syntax and code generation tool called `goctl`. You can generate Go, iOS, Android, Kotlin, Dart, TypeScript, JavaScript from .api files with `goctl`.
@@ -115,11 +115,11 @@ go get -u github.com/tal-tech/go-zero
type Request struct {
Name string `path:"name,options=you|me"` // parameters are auto validated
}
type Response struct {
Message string `json:"message"`
}
service greet-api {
@handler GreetHandler
get /greet/from/:name(Request) returns (Response);
@@ -200,6 +200,8 @@ go get -u github.com/tal-tech/go-zero
## 7. Benchmark
Document: [https://go-zero.dev/en/](https://go-zero.dev/en/)
![benchmark](https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/benchmark.png)
[Checkout the test code](https://github.com/smallnest/go-web-framework-benchmark)
@@ -210,6 +212,20 @@ go get -u github.com/tal-tech/go-zero
* [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore-en.md)
* [Examples](https://github.com/zeromicro/zero-examples)
## 9. Chat group
## 9. Important notes
* Use grpc 1.29.1, because etcd lib doesnt support latter versions.
`google.golang.org/grpc v1.29.1`
* For protobuf compatibility, use `protocol-gen@v1.3.2`.
` go get -u github.com/golang/protobuf/protoc-gen-go@v1.3.2`
## 10. Chat group
Join the chat via https://join.slack.com/t/go-zeroworkspace/shared_invite/zt-m39xssxc-kgIqERa7aVsujKNj~XuPKg
## Give a Star! ⭐
If you like or are using this project to learn or start your solution, please give it a star. Thanks!

View File

@@ -109,13 +109,13 @@ func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat
chain := alice.New(
handler.TracingHandler,
s.getLogHandler(),
handler.PrometheusHandler(route.Path),
handler.MaxConns(s.conf.MaxConns),
handler.BreakerHandler(route.Method, route.Path, metrics),
handler.SheddingHandler(s.getShedder(fr.priority), metrics),
handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
handler.RecoverHandler,
handler.MetricHandler(metrics),
handler.PrometheusHandler(route.Path),
handler.MaxBytesHandler(s.conf.MaxBytes),
handler.GunzipHandler,
)

View File

@@ -154,8 +154,7 @@ Verbose: true
}
}
type mockedRouter struct {
}
type mockedRouter struct{}
func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
}

View File

@@ -1,8 +1,10 @@
package handler
import (
"bufio"
"context"
"errors"
"net"
"net/http"
"net/http/httputil"
@@ -138,6 +140,16 @@ func (grw *guardedResponseWriter) Header() http.Header {
return grw.writer.Header()
}
// Hijack implements the http.Hijacker interface.
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
func (grw *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacked, ok := grw.writer.(http.Hijacker); ok {
return hijacked.Hijack()
}
return nil, nil, errors.New("server doesn't support hijacking")
}
func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
return grw.writer.Write(body)
}

View File

@@ -1,6 +1,8 @@
package handler
import (
"bufio"
"net"
"net/http"
"net/http/httptest"
"testing"
@@ -87,6 +89,26 @@ func TestAuthHandler_NilError(t *testing.T) {
})
}
func TestAuthHandler_Flush(t *testing.T) {
resp := httptest.NewRecorder()
handler := newGuardedResponseWriter(resp)
handler.Flush()
assert.True(t, resp.Flushed)
}
func TestAuthHandler_Hijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := newGuardedResponseWriter(resp)
assert.NotPanics(t, func() {
writer.Hijack()
})
writer = newGuardedResponseWriter(mockedHijackable{resp})
assert.NotPanics(t, func() {
writer.Hijack()
})
}
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
now := time.Now().Unix()
claims := make(jwt.MapClaims)
@@ -101,3 +123,11 @@ func buildToken(secretKey string, payloads map[string]interface{}, seconds int64
return token.SignedString([]byte(secretKey))
}
type mockedHijackable struct {
*httptest.ResponseRecorder
}
func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, nil
}

View File

@@ -1,11 +1,13 @@
package handler
import (
"bufio"
"bytes"
"encoding/base64"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"github.com/tal-tech/go-zero/core/codec"
@@ -94,6 +96,16 @@ func (w *cryptionResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}
// Hijack implements the http.Hijacker interface.
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
func (w *cryptionResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacked, ok := w.ResponseWriter.(http.Hijacker); ok {
return hijacked.Hijack()
}
return nil, nil, errors.New("server doesn't support hijacking")
}
func (w *cryptionResponseWriter) Write(p []byte) (int, error) {
return w.buf.Write(p)
}

View File

@@ -103,3 +103,16 @@ func TestCryptionHandlerFlush(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
}
func TestCryptionHandler_Hijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := newCryptionResponseWriter(resp)
assert.NotPanics(t, func() {
writer.Hijack()
})
writer = newCryptionResponseWriter(mockedHijackable{resp})
assert.NotPanics(t, func() {
writer.Hijack()
})
}

View File

@@ -1,10 +1,13 @@
package handler
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"time"
@@ -25,10 +28,26 @@ type loggedResponseWriter struct {
code int
}
func (w *loggedResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *loggedResponseWriter) Header() http.Header {
return w.w.Header()
}
// Hijack implements the http.Hijacker interface.
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
func (w *loggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacked, ok := w.w.(http.Hijacker); ok {
return hijacked.Hijack()
}
return nil, nil, errors.New("server doesn't support hijacking")
}
func (w *loggedResponseWriter) Write(bytes []byte) (int, error) {
return w.w.Write(bytes)
}
@@ -38,12 +57,6 @@ func (w *loggedResponseWriter) WriteHeader(code int) {
w.code = code
}
func (w *loggedResponseWriter) Flush() {
if flusher, ok := w.w.(http.Flusher); ok {
flusher.Flush()
}
}
// LogHandler returns a middleware that logs http request and response.
func LogHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -83,6 +96,16 @@ func (w *detailLoggedResponseWriter) Header() http.Header {
return w.writer.Header()
}
// Hijack implements the http.Hijacker interface.
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
func (w *detailLoggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacked, ok := w.writer.w.(http.Hijacker); ok {
return hijacked.Hijack()
}
return nil, nil, errors.New("server doesn't support hijacking")
}
func (w *detailLoggedResponseWriter) Write(bs []byte) (int, error) {
w.buf.Write(bs)
return w.writer.Write(bs)

View File

@@ -62,6 +62,44 @@ func TestLogHandlerSlow(t *testing.T) {
}
}
func TestLogHandler_Hijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := &loggedResponseWriter{
w: resp,
}
assert.NotPanics(t, func() {
writer.Hijack()
})
writer = &loggedResponseWriter{
w: mockedHijackable{resp},
}
assert.NotPanics(t, func() {
writer.Hijack()
})
}
func TestDetailedLogHandler_Hijack(t *testing.T) {
resp := httptest.NewRecorder()
writer := &detailLoggedResponseWriter{
writer: &loggedResponseWriter{
w: resp,
},
}
assert.NotPanics(t, func() {
writer.Hijack()
})
writer = &detailLoggedResponseWriter{
writer: &loggedResponseWriter{
w: mockedHijackable{resp},
},
}
assert.NotPanics(t, func() {
writer.Hijack()
})
}
func BenchmarkLogHandler(b *testing.B) {
b.ReportAllocs()

View File

@@ -6,6 +6,7 @@ import (
"time"
"github.com/tal-tech/go-zero/core/metric"
"github.com/tal-tech/go-zero/core/prometheus"
"github.com/tal-tech/go-zero/core/timex"
"github.com/tal-tech/go-zero/rest/internal/security"
)
@@ -34,6 +35,10 @@ var (
// PrometheusHandler returns a middleware that reports stats to prometheus.
func PrometheusHandler(path string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
if !prometheus.Enabled() {
return next
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
startTime := timex.Now()
cw := &security.WithCodeResponseWriter{Writer: w}

View File

@@ -6,9 +6,26 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/prometheus"
)
func TestPromMetricHandler(t *testing.T) {
func TestPromMetricHandler_Disabled(t *testing.T) {
promMetricHandler := PrometheusHandler("/user/login")
handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
}
func TestPromMetricHandler_Enabled(t *testing.T) {
prometheus.StartAgent(prometheus.Config{
Host: "localhost",
Path: "/",
})
promMetricHandler := PrometheusHandler("/user/login")
handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)

View File

@@ -94,8 +94,7 @@ func (s mockShedder) Allow() (load.Promise, error) {
return nil, load.ErrServiceOverloaded
}
type mockPromise struct {
}
type mockPromise struct{}
func (p mockPromise) Pass() {
}

View File

@@ -1,6 +1,10 @@
package security
import "net/http"
import (
"bufio"
"net"
"net/http"
)
// A WithCodeResponseWriter is a helper to delay sealing a http.ResponseWriter on writing code.
type WithCodeResponseWriter struct {
@@ -20,6 +24,12 @@ func (w *WithCodeResponseWriter) Header() http.Header {
return w.Writer.Header()
}
// Hijack implements the http.Hijacker interface.
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.Writer.(http.Hijacker).Hijack()
}
// Write writes bytes into w.
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
return w.Writer.Write(bytes)

View File

@@ -77,6 +77,8 @@ func (e *Server) AddRoute(r Route, opts ...RouteOption) {
}
// Start starts the Server.
// Graceful shutdown is enabled by default.
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
func (e *Server) Start() {
handleError(e.opts.start(e.ngin))
}
@@ -108,7 +110,7 @@ func WithJwt(secret string) RouteOption {
}
// WithJwtTransition returns a func to enable jwt authentication as well as jwt secret transition.
// Which means old and new jwt secrets work together for a peroid.
// Which means old and new jwt secrets work together for a period.
func WithJwtTransition(secret, prevSecret string) RouteOption {
return func(r *featuredRoutes) {
// why not validate prevSecret, because prevSecret is an already used one,

View File

@@ -29,7 +29,7 @@ Future {{pathToFuncName .Path}}( {{if ne .Method "get"}}{{with .RequestType}}{{.
{{end}}`
func genApi(dir string, api *spec.ApiSpec) error {
err := os.MkdirAll(dir, 0755)
err := os.MkdirAll(dir, 0o755)
if err != nil {
return err
}
@@ -39,7 +39,7 @@ func genApi(dir string, api *spec.ApiSpec) error {
return err
}
file, err := os.OpenFile(dir+api.Service.Name+".dart", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
file, err := os.OpenFile(dir+api.Service.Name+".dart", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
if err != nil {
return err
}
@@ -60,7 +60,7 @@ func genApiFile(dir string) error {
if fileExists(path) {
return nil
}
apiFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
apiFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
if err != nil {
return err
}

View File

@@ -32,7 +32,7 @@ class {{.Name}}{
`
func genData(dir string, api *spec.ApiSpec) error {
err := os.MkdirAll(dir, 0755)
err := os.MkdirAll(dir, 0o755)
if err != nil {
return err
}
@@ -42,7 +42,7 @@ func genData(dir string, api *spec.ApiSpec) error {
return err
}
file, err := os.OpenFile(dir+api.Service.Name+".dart", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
file, err := os.OpenFile(dir+api.Service.Name+".dart", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
if err != nil {
return err
}
@@ -64,7 +64,7 @@ func genTokens(dir string) error {
return nil
}
tokensFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
tokensFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
if err != nil {
return err
}

View File

@@ -41,20 +41,20 @@ Future<Tokens> getTokens() async {
`
func genVars(dir string) error {
err := os.MkdirAll(dir, 0755)
err := os.MkdirAll(dir, 0o755)
if err != nil {
return err
}
if !fileExists(dir + "vars.dart") {
err = ioutil.WriteFile(dir+"vars.dart", []byte(`const serverHost='demo-crm.xiaoheiban.cn';`), 0644)
err = ioutil.WriteFile(dir+"vars.dart", []byte(`const serverHost='demo-crm.xiaoheiban.cn';`), 0o644)
if err != nil {
return err
}
}
if !fileExists(dir + "kv.dart") {
err = ioutil.WriteFile(dir+"kv.dart", []byte(varTemplate), 0644)
err = ioutil.WriteFile(dir+"kv.dart", []byte(varTemplate), 0o644)
if err != nil {
return err
}

View File

@@ -84,7 +84,7 @@ func buildDoc(route spec.Type) (string, error) {
return "", nil
}
var tps = make([]spec.Type, 0)
tps := make([]spec.Type, 0)
tps = append(tps, route)
if definedType, ok := route.(spec.DefineStruct); ok {
associatedTypes(definedType, &tps)
@@ -98,7 +98,7 @@ func buildDoc(route spec.Type) (string, error) {
}
func associatedTypes(tp spec.DefineStruct, tps *[]spec.Type) {
var hasAdded = false
hasAdded := false
for _, item := range *tps {
if item.Name() == tp.Name() {
hasAdded = true

View File

@@ -107,8 +107,8 @@ func apiFormat(data string) (string, error) {
var builder strings.Builder
s := bufio.NewScanner(strings.NewReader(data))
var tapCount = 0
var newLineCount = 0
tapCount := 0
newLineCount := 0
var preLine string
for s.Scan() {
line := strings.TrimSpace(s.Text())

View File

@@ -35,12 +35,12 @@ func genConfig(dir string, cfg *config.Config, api *spec.ApiSpec) error {
return err
}
var authNames = getAuths(api)
authNames := getAuths(api)
var auths []string
for _, item := range authNames {
auths = append(auths, fmt.Sprintf("%s %s", item, jwtTemplate))
}
var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL)
authImportStr := fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL)
return genFile(fileGenConfig{
dir: dir,

View File

@@ -31,7 +31,7 @@ func (m *{{.name}})Handle(next http.HandlerFunc) http.HandlerFunc {
`
func genMiddleware(dir string, cfg *config.Config, api *spec.ApiSpec) error {
var middlewares = getMiddleware(api)
middlewares := getMiddleware(api)
for _, item := range middlewares {
middlewareFilename := strings.TrimSuffix(strings.ToLower(item), "middleware") + "_middleware"
filename, err := format.FileNamingFormat(cfg.NamingFormat, middlewareFilename)

View File

@@ -95,11 +95,11 @@ func genRoutes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
var routes string
if len(g.middlewares) > 0 {
gbuilder.WriteString("\n}...,")
var params = g.middlewares
params := g.middlewares
for i := range params {
params[i] = "serverCtx." + params[i]
}
var middlewareStr = strings.Join(params, ", ")
middlewareStr := strings.Join(params, ", ")
routes = fmt.Sprintf("rest.WithMiddlewares(\n[]rest.Middleware{ %s }, \n %s \n),",
middlewareStr, strings.TrimSpace(gbuilder.String()))
} else {
@@ -146,7 +146,7 @@ func genRoutes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
}
func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
var importSet = collection.NewSet()
importSet := collection.NewSet()
importSet.AddStr(fmt.Sprintf("\"%s\"", util.JoinPackages(parentPkg, contextDir)))
for _, group := range api.Service.Groups {
for _, route := range group.Routes {

View File

@@ -39,7 +39,7 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
return err
}
var authNames = getAuths(api)
authNames := getAuths(api)
var auths []string
for _, item := range authNames {
auths = append(auths, fmt.Sprintf("%s config.AuthConfig", item))
@@ -52,7 +52,7 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
var middlewareStr string
var middlewareAssignment string
var middlewares = getMiddleware(api)
middlewares := getMiddleware(api)
for _, item := range middlewares {
middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
@@ -61,7 +61,7 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
fmt.Sprintf("middleware.New%s().%s", strings.Title(name), "Handle"))
}
var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
configImport := "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
if len(middlewareStr) > 0 {
configImport += "\n\t\"" + ctlutil.JoinPackages(parentPkg, middlewareDir) + "\""
configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceURL)

View File

@@ -103,17 +103,9 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
}
func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error {
defineStruct, ok := ty.(spec.DefineStruct)
if !ok {
return errors.New("unsupported type %s" + ty.Name())
}
for _, item := range c.requestTypes {
if item.Name() == defineStruct.Name() {
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
return nil
}
}
defineStruct, done, err := c.checkStruct(ty)
if done {
return err
}
modelFile := util.Title(ty.Name()) + ".java"
@@ -181,6 +173,22 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
return err
}
func (c *componentsContext) checkStruct(ty spec.Type) (spec.DefineStruct, bool, error) {
defineStruct, ok := ty.(spec.DefineStruct)
if !ok {
return spec.DefineStruct{}, true, errors.New("unsupported type %s" + ty.Name())
}
for _, item := range c.requestTypes {
if item.Name() == defineStruct.Name() {
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
return spec.DefineStruct{}, true, nil
}
}
}
return defineStruct, false, nil
}
func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) {
var builder strings.Builder
if err := c.writeType(&builder, defineStruct); err != nil {
@@ -269,15 +277,15 @@ func (c *componentsContext) buildConstructor() (string, string, error) {
}
func (c *componentsContext) genGetSet(writer io.Writer, indent int) error {
var members = c.members
members := c.members
for _, member := range members {
javaType, err := specTypeToJava(member.Type)
if err != nil {
return nil
}
var property = util.Title(member.Name)
var templateStr = getSetTemplate
property := util.Title(member.Name)
templateStr := getSetTemplate
if javaType == "boolean" {
templateStr = boolTemplate
property = strings.TrimPrefix(property, "Is")

View File

@@ -67,7 +67,7 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
}
defer fp.Close()
var hasRequestBody = false
hasRequestBody := false
if route.RequestType != nil {
if defineStruct, ok := route.RequestType.(spec.DefineStruct); ok {
hasRequestBody = len(defineStruct.GetBodyMembers()) > 0 || len(defineStruct.GetFormMembers()) > 0

View File

@@ -61,7 +61,7 @@ func writeIndent(writer io.Writer, indent int) {
}
func indentString(indent int) string {
var result = ""
result := ""
for i := 0; i < indent; i++ {
result += "\t"
}
@@ -95,17 +95,9 @@ func specTypeToJava(tp spec.Type) (string, error) {
return "", err
}
switch valueType {
case "int":
return "Integer[]", nil
case "long":
return "Long[]", nil
case "float":
return "Float[]", nil
case "double":
return "Double[]", nil
case "boolean":
return "Boolean[]", nil
s := getBaseType(valueType)
if len(s) == 0 {
return s, errors.New("unsupported primitive type " + tp.Name())
}
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil
@@ -118,6 +110,23 @@ func specTypeToJava(tp spec.Type) (string, error) {
return "", errors.New("unsupported primitive type " + tp.Name())
}
func getBaseType(valueType string) string {
switch valueType {
case "int":
return "Integer[]"
case "long":
return "Long[]"
case "float":
return "Float[]"
case "double":
return "Double[]"
case "boolean":
return "Boolean[]"
default:
return ""
}
}
func primitiveType(tp string) (string, bool) {
switch tp {
case "string":

View File

@@ -98,7 +98,7 @@ object {{with .Info}}{{.Title}}{{end}}{
)
func genBase(dir, pkg string, api *spec.ApiSpec) error {
e := os.MkdirAll(dir, 0755)
e := os.MkdirAll(dir, 0o755)
if e != nil {
return e
}
@@ -108,7 +108,7 @@ func genBase(dir, pkg string, api *spec.ApiSpec) error {
return nil
}
file, e := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
file, e := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
if e != nil {
return e
}
@@ -146,12 +146,12 @@ func genApi(dir, pkg string, api *spec.ApiSpec) error {
api.Info.Title = name
api.Info.Desc = desc
e := os.MkdirAll(dir, 0755)
e := os.MkdirAll(dir, 0o755)
if e != nil {
return e
}
file, e := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
file, e := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0o644)
if e != nil {
return e
}

View File

@@ -33,7 +33,7 @@ func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
for _, each := range ctx.AllSpec() {
root := each.Accept(v).(*Api)
v.acceptSyntax(root, &final)
v.accetpImport(root, &final)
v.acceptImport(root, &final)
v.acceptInfo(root, &final)
v.acceptType(root, &final)
v.acceptService(root, &final)
@@ -133,7 +133,7 @@ func (v *ApiVisitor) acceptInfo(root *Api, final *Api) {
}
}
func (v *ApiVisitor) accetpImport(root *Api, final *Api) {
func (v *ApiVisitor) acceptImport(root *Api, final *Api) {
for _, imp := range root.Import {
if _, ok := final.importM[imp.Value.Text()]; ok {
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))

View File

@@ -6,9 +6,9 @@ import (
"path/filepath"
"strings"
"github.com/antlr/antlr4/runtime/Go/antlr"
"github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/zeromicro/antlr"
)
type (
@@ -175,7 +175,7 @@ func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
for _, g := range list {
handler := g.GetHandler()
if handler.IsNotNil() {
var handlerName = handler.Text()
handlerName := handler.Text()
handlerMap[handlerName] = Holder
path := fmt.Sprintf("%s://%s", g.Route.Method.Text(), g.Route.Path.Text())
routeMap[path] = Holder

View File

@@ -5,9 +5,9 @@ import (
"sort"
"strings"
"github.com/antlr/antlr4/runtime/Go/antlr"
"github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/zeromicro/antlr"
)
type (
@@ -289,7 +289,7 @@ func (v *ApiVisitor) getHiddenTokensToLeft(t TokenStream, channel int, containsC
if index > 0 {
allTokens := ct.GetAllTokens()
var flag = false
flag := false
for i := index; i >= 0; i-- {
tk := allTokens[i]
if tk.GetChannel() == antlr.LexerDefaultTokenChannel {

View File

@@ -50,7 +50,7 @@ type AtDoc struct {
Kv []*KvExpr
}
// AtHandler describes service hander ast for api syntax
// AtHandler describes service handler ast for api syntax
type AtHandler struct {
AtHandlerToken Expr
Name Expr
@@ -630,7 +630,7 @@ func (s *Service) Equal(v interface{}) bool {
return s.ServiceApi.Equal(service.ServiceApi)
}
// Get returns the tergate KV by specified key
// Get returns the target KV by specified key
func (kv KV) Get(key string) Expr {
for _, each := range kv {
if each.Key.Text() == key {

View File

@@ -17,7 +17,7 @@ type (
NameExpr() Expr
}
// TypeAlias describes alias ast for api syatax
// TypeAlias describes alias ast for api syntax
TypeAlias struct {
Name Expr
Assign Expr
@@ -26,7 +26,7 @@ type (
CommentExpr Expr
}
// TypeStruct describes structure ast for api syatax
// TypeStruct describes structure ast for api syntax
TypeStruct struct {
Name Expr
Struct Expr
@@ -128,7 +128,6 @@ func (v *ApiVisitor) VisitTypeBlock(ctx *api.TypeBlockContext) interface{} {
var types []TypeExpr
for _, each := range list {
types = append(types, each.Accept(v).(TypeExpr))
}
return types
}
@@ -155,7 +154,6 @@ func (v *ApiVisitor) VisitTypeStruct(ctx *api.TypeStructContext) interface{} {
st.Name = v.newExprWithToken(ctx.GetStructName())
if util.UnExport(ctx.GetStructName().GetText()) {
}
if ctx.GetStructToken() != nil {
structExpr := v.newExprWithToken(ctx.GetStructToken())
@@ -225,7 +223,7 @@ func (v *ApiVisitor) VisitTypeBlockAlias(ctx *api.TypeBlockAliasContext) interfa
alias.DocExpr = v.getDoc(ctx)
alias.CommentExpr = v.getComment(ctx)
// todo: reopen if necessary
v.panic(alias.Name, "unsupport alias")
v.panic(alias.Name, "unsupported alias")
return &alias
}
@@ -238,7 +236,7 @@ func (v *ApiVisitor) VisitTypeAlias(ctx *api.TypeAliasContext) interface{} {
alias.DocExpr = v.getDoc(ctx)
alias.CommentExpr = v.getComment(ctx)
// todo: reopen if necessary
v.panic(alias.Name, "unsupport alias")
v.panic(alias.Name, "unsupported alias")
return &alias
}
@@ -319,7 +317,7 @@ func (v *ApiVisitor) VisitDataType(ctx *api.DataTypeContext) interface{} {
if ctx.GetTime() != nil {
// todo: reopen if it is necessary
timeExpr := v.newExprWithToken(ctx.GetTime())
v.panic(timeExpr, "unsupport time.Time")
v.panic(timeExpr, "unsupported time.Time")
return &Time{Literal: timeExpr}
}
if ctx.PointerType() != nil {

View File

@@ -1,7 +1,7 @@
// Code generated from tools/goctl/api/parser/g4/ApiParser.g4 by ANTLR 4.9. DO NOT EDIT.
package api // ApiParser
import "github.com/antlr/antlr4/runtime/Go/antlr"
import "github.com/zeromicro/antlr"
type BaseApiParserVisitor struct {
*antlr.BaseParseTreeVisitor

View File

@@ -6,12 +6,14 @@ import (
"fmt"
"unicode"
"github.com/antlr/antlr4/runtime/Go/antlr"
"github.com/zeromicro/antlr"
)
// Suppress unused import error
var _ = fmt.Printf
var _ = unicode.IsLetter
var (
_ = fmt.Printf
_ = unicode.IsLetter
)
var serializedLexerAtn = []uint16{
3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 25, 266,

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
// Code generated from tools/goctl/api/parser/g4/ApiParser.g4 by ANTLR 4.9. DO NOT EDIT.
package api // ApiParser
import "github.com/antlr/antlr4/runtime/Go/antlr"
import "github.com/zeromicro/antlr"
// A complete Visitor for a parse tree produced by ApiParserParser.
type ApiParserVisitor interface {

View File

@@ -7,7 +7,7 @@ import (
"strings"
"unicode"
"github.com/antlr/antlr4/runtime/Go/antlr"
"github.com/zeromicro/antlr"
)
const (
@@ -16,28 +16,30 @@ const (
tagRegex = `(?m)\x60[a-z]+:".+"\x60`
)
var holder = struct{}{}
var kind = map[string]struct{}{
"bool": holder,
"int": holder,
"int8": holder,
"int16": holder,
"int32": holder,
"int64": holder,
"uint": holder,
"uint8": holder,
"uint16": holder,
"uint32": holder,
"uint64": holder,
"uintptr": holder,
"float32": holder,
"float64": holder,
"complex64": holder,
"complex128": holder,
"string": holder,
"byte": holder,
"rune": holder,
}
var (
holder = struct{}{}
kind = map[string]struct{}{
"bool": holder,
"int": holder,
"int8": holder,
"int16": holder,
"int32": holder,
"int64": holder,
"uint": holder,
"uint8": holder,
"uint16": holder,
"uint32": holder,
"uint64": holder,
"uintptr": holder,
"float32": holder,
"float64": holder,
"complex64": holder,
"complex128": holder,
"string": holder,
"byte": holder,
"rune": holder,
}
)
func match(p *ApiParserParser, text string) {
v := getCurrentTokenText(p)

View File

@@ -145,7 +145,6 @@ line"`),
},
},
}))
})
t.Run("mismatched", func(t *testing.T) {

View File

@@ -120,7 +120,8 @@ func TestRoute(t *testing.T) {
PointerExpr: ast.NewTextExpr("*Bar"),
Star: ast.NewTextExpr("*"),
Name: ast.NewTextExpr("Bar"),
}},
},
},
},
}))
@@ -224,7 +225,6 @@ func TestAtHandler(t *testing.T) {
_, err = parser.Accept(fn, `@handler "foo"`)
assert.Error(t, err)
})
}
func TestAtDoc(t *testing.T) {

View File

@@ -64,7 +64,7 @@ func (p parser) convert2Spec() error {
}
func (p parser) fillInfo() {
var properties = make(map[string]string, 0)
properties := make(map[string]string, 0)
if p.ast.Info != nil {
p.spec.Info = spec.Info{}
for _, kv := range p.ast.Info.Kvs {
@@ -147,8 +147,8 @@ func (p parser) findDefinedType(name string) (*spec.Type, error) {
}
func (p parser) fieldToMember(field *ast.TypeField) spec.Member {
var name = ""
var tag = ""
name := ""
tag := ""
if !field.IsAnonymous {
name = field.Name.Text()
if field.Tag == nil {
@@ -219,9 +219,9 @@ func (p parser) fillService() error {
for _, astRoute := range item.ServiceApi.ServiceRoute {
route := spec.Route{
Annotation: spec.Annotation{},
Method: astRoute.Route.Method.Text(),
Path: astRoute.Route.Path.Text(),
AtServerAnnotation: spec.Annotation{},
Method: astRoute.Route.Method.Text(),
Path: astRoute.Route.Path.Text(),
}
if astRoute.AtHandler != nil {
route.Handler = astRoute.AtHandler.Name.Text()
@@ -239,7 +239,7 @@ func (p parser) fillService() error {
route.ResponseType = p.astTypeToSpec(astRoute.Route.Reply.Name)
}
if astRoute.AtDoc != nil {
var properties = make(map[string]string, 0)
properties := make(map[string]string, 0)
for _, kv := range astRoute.AtDoc.Kv {
properties[kv.Key.Text()] = kv.Value.Text()
}
@@ -271,11 +271,11 @@ func (p parser) fillService() error {
func (p parser) fillRouteAtServer(astRoute *ast.ServiceRoute, route *spec.Route) error {
if astRoute.AtServer != nil {
var properties = make(map[string]string, 0)
properties := make(map[string]string, 0)
for _, kv := range astRoute.AtServer.Kv {
properties[kv.Key.Text()] = kv.Value.Text()
}
route.Annotation.Properties = properties
route.AtServerAnnotation.Properties = properties
if len(route.Handler) == 0 {
route.Handler = properties["handler"]
}
@@ -295,7 +295,7 @@ func (p parser) fillRouteAtServer(astRoute *ast.ServiceRoute, route *spec.Route)
func (p parser) fillAtServer(item *ast.Service, group *spec.Group) {
if item.AtServer != nil {
var properties = make(map[string]string, 0)
properties := make(map[string]string, 0)
for _, kv := range item.AtServer.Kv {
properties[kv.Key.Text()] = kv.Value.Text()
}

View File

@@ -11,10 +11,11 @@ import (
const (
bodyTagKey = "json"
formTagKey = "form"
pathTagKey = "path"
defaultSummaryKey = "summary"
)
var definedKeys = []string{bodyTagKey, formTagKey, "path"}
var definedKeys = []string{bodyTagKey, formTagKey, pathTagKey}
// Routes returns all routes in api service
func (s Service) Routes() []Route {
@@ -25,7 +26,7 @@ func (s Service) Routes() []Route {
return result
}
// Tags retuens all tags in Member
// Tags returns all tags in Member
func (m Member) Tags() []*Tag {
tags, err := Parse(m.Tag)
if err != nil {
@@ -141,7 +142,7 @@ func (t DefineStruct) GetFormMembers() []Member {
return result
}
// GetNonBodyMembers retruns all have no tag fields
// GetNonBodyMembers returns all have no tag fields
func (t DefineStruct) GetNonBodyMembers() []Member {
var result []Member
for _, member := range t.Members {
@@ -162,16 +163,16 @@ func (r Route) JoinedDoc() string {
return strings.TrimSpace(doc)
}
// GetAnnotation returns the value by specified key
// GetAnnotation returns the value by specified key from @server
func (r Route) GetAnnotation(key string) string {
if r.Annotation.Properties == nil {
if r.AtServerAnnotation.Properties == nil {
return ""
}
return r.Annotation.Properties[key]
return r.AtServerAnnotation.Properties[key]
}
// GetAnnotation returns the value by specified key
// GetAnnotation returns the value by specified key from @server
func (g Group) GetAnnotation(key string) string {
if g.Annotation.Properties == nil {
return ""

View File

@@ -63,14 +63,14 @@ type (
// Route describes api route
Route struct {
Annotation Annotation
Method string
Path string
RequestType Type
ResponseType Type
Docs Doc
Handler string
AtDoc AtDoc
AtServerAnnotation Annotation
Method string
Path string
RequestType Type
ResponseType Type
Docs Doc
Handler string
AtDoc AtDoc
}
// Service describes api service

Some files were not shown because too many files have changed in this diff Show More