Compare commits

...

94 Commits

Author SHA1 Message Date
kevin
fe855c52f1 avoid bigint converted into float64 when unmarshaling 2020-10-10 15:24:29 +08:00
kevin
3f8b080882 add more tests 2020-10-10 13:47:55 +08:00
kevin
adc275872d add more tests 2020-10-10 11:53:49 +08:00
kevin
be39133dba fix data race in tests 2020-10-09 19:13:10 +08:00
kingxt
15a9ab1d18 parser ad test (#116)
* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl support import

* goctl support import

* support return ()

* revert

* refactor and rename folder to group

* parser add test

Co-authored-by: kingxt <dream4kingxt@163.com>
2020-10-09 16:03:00 +08:00
kevin
7c354dcc38 add more tests 2020-10-09 14:53:13 +08:00
kevin
3733b06f1b fix data race in tests 2020-10-09 14:15:27 +08:00
kevin
8115a0932e add more tests 2020-10-09 13:59:38 +08:00
kevin
4df5eb760c add more tests 2020-10-08 22:39:07 +08:00
kevin
4a639b853c add more tests 2020-10-08 09:42:20 +08:00
kevin
1023425c1d add more tests 2020-10-07 23:15:34 +08:00
kevin
360fbfd0fa add more tests 2020-10-07 23:02:58 +08:00
kevin
09b7625f06 add more tests 2020-10-07 22:54:51 +08:00
kevin
6db294b5cc add more tests 2020-10-07 19:33:52 +08:00
kevin
305b6749fd add more tests 2020-10-07 19:13:19 +08:00
kevin
10b855713d add more tests 2020-10-07 19:00:15 +08:00
kevin
1cc0f071d9 add more tests 2020-10-07 18:07:54 +08:00
kevin
02ce8f82c8 add more tests 2020-10-07 11:43:02 +08:00
kevin
8a585afbf0 add more tests 2020-10-07 11:19:10 +08:00
kevin
e356025cef add more tests 2020-10-07 08:11:20 +08:00
kevin
14dee114dd add more tests 2020-10-06 10:12:35 +08:00
kevin
637a94a189 add fx.Count 2020-10-05 18:17:59 +08:00
kevin
173b347c90 add more tests 2020-10-05 12:19:54 +08:00
kevin
6749c5b94a add more tests 2020-10-04 17:52:54 +08:00
刘青
e66cca3710 breaker: remover useless code (#114) 2020-10-04 16:25:26 +08:00
kevin
f90c0aa98e update wechat qrcode 2020-10-04 10:14:08 +08:00
kevin
f00b5416a3 update codecov settings 2020-10-03 23:09:29 +08:00
kevin
f49694d6b6 fix data race 2020-10-02 22:41:25 +08:00
kevin
d809bf2dca add more tests 2020-10-02 22:37:15 +08:00
kevin
44ae5463bc add more tests 2020-10-02 09:00:25 +08:00
kevin
40dbd722d7 add more tests 2020-10-01 23:29:49 +08:00
kevin
709574133b add more tests 2020-10-01 23:22:53 +08:00
kevin
cb1c593108 remove markdown linter 2020-10-01 21:11:19 +08:00
kevin
6ecf575c00 add more tests 2020-10-01 20:58:12 +08:00
kevin
b8fcdd5460 add more tests 2020-10-01 17:50:53 +08:00
kevin
ce42281568 add more tests 2020-10-01 17:27:21 +08:00
kevin
40230d79e7 fix data race 2020-10-01 16:58:07 +08:00
kevin
ba7851795b add more tests 2020-10-01 16:49:39 +08:00
kevin
096fe3bc47 add more tests 2020-10-01 11:57:06 +08:00
kevin
e37858295a add more tests 2020-10-01 11:49:17 +08:00
kevin
5a4afb1518 add more tests 2020-10-01 10:29:03 +08:00
kevin
63f1f39c40 fix int64 primary key problem 2020-09-30 22:25:47 +08:00
kevin
481895d1e4 add more tests 2020-09-30 17:47:56 +08:00
shenbaise9527
9e9ce3bf48 GetBreaker need double-check (#112) 2020-09-30 16:50:02 +08:00
kevin
0ce654968d add more tests 2020-09-30 15:36:13 +08:00
Percy Gauguin
2703493541 update: fix wrong word (#110) 2020-09-30 15:08:47 +08:00
janetyu
d4240cd4b0 perfect the bookstore and shorturl doc (#109)
* perfect the bookstore and shorturl doc

* 避免歧义
2020-09-30 14:22:37 +08:00
kevin
a22bcc84a3 better lock practice in sharedcalls 2020-09-30 12:31:35 +08:00
kevin
93f430a449 update shorturl doc 2020-09-29 17:36:00 +08:00
kevin
d1b303fe7e export cache package, add client interceptor customization 2020-09-29 17:25:49 +08:00
kevin
dbca20e3df add zrpc client interceptor 2020-09-29 16:09:11 +08:00
boob
b3ead4d76c doc: update sharedcalls.md layout (#107) 2020-09-29 14:32:17 +08:00
kevin
33a9db85c8 add unit test, fix interceptor bug 2020-09-29 14:30:22 +08:00
kingxt
e7d46aa6e2 refactor and rename folder to group (#106)
* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl support import

* goctl support import

* support return ()

* support return ()

* revert

* format api

* refactor and rename folder to group

Co-authored-by: kingxt <dream4kingxt@163.com>
2020-09-29 11:14:52 +08:00
kevin
b282304054 add api doc link 2020-09-28 16:58:29 +08:00
bittoy
0a36031d48 use default mongo db (#103) 2020-09-28 16:35:07 +08:00
kevin
e5d7c3ab04 unmarshal should be struct 2020-09-28 15:19:30 +08:00
kevin
12c08bfd39 Revert "goreportcard not working, remove it temporarily"
This reverts commit 8f465fa439.
2020-09-28 11:41:23 +08:00
kevin
8f465fa439 goreportcard not working, remove it temporarily 2020-09-28 00:31:24 +08:00
kingxt
8a470bb6ee support return () syntax (#101)
* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl support import

* goctl support import

* support return ()

* support return ()

* remove pwd for windows not support

* revert

* remove no need

Co-authored-by: kingxt <dream4kingxt@163.com>
2020-09-27 17:23:15 +08:00
kevin
9277ad77f7 fix typo of prometheus 2020-09-27 17:15:15 +08:00
kevin
a958400595 rename prommetric to prometheous, add unit tests 2020-09-27 16:14:16 +08:00
kevin
015716d1b5 update wechat and etcd yaml 2020-09-27 14:15:33 +08:00
kevin
54e9d01312 update example 2020-09-27 11:10:21 +08:00
kevin
bc831b75dd export AddOptions, AddStreamInterceptors, AddUnaryInterceptors 2020-09-26 22:05:57 +08:00
kevin
ff112fdaee query from cache first when do cache.Take 2020-09-26 21:58:46 +08:00
kingxt
8d0f7dbb27 rename (#98)
* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl support import

* goctl support import

* rename

Co-authored-by: kingxt <dream4kingxt@163.com>
2020-09-24 10:31:49 +08:00
Keson
a5ce2c448e fix bug: module parse error (#97) 2020-09-23 22:10:25 +08:00
kevin
0dd8e27557 add more clear error when rpc service is not started 2020-09-23 22:07:26 +08:00
Zhang Hao
17a0908a84 add test (#95) 2020-09-22 19:15:30 +08:00
Keson
9f9c24cce9 fix bug: release empty struct limit (#96) 2020-09-22 19:13:46 +08:00
kingxt
b628bc0086 goctl support import api file (#94)
* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl support import

* goctl support import

Co-authored-by: kingxt <dream4kingxt@163.com>
2020-09-22 18:32:26 +08:00
kevin
be9c48da7f add tracing logs in server side and client side 2020-09-22 17:34:39 +08:00
kevin
797a90ae7d remove unnecessary tag 2020-09-21 22:41:14 +08:00
kevin
92e60a5777 use options instead of opts in error message 2020-09-21 22:37:07 +08:00
miaogaolin
46995a4d7d 修改不能编辑代码注释 (#92)
* rename file and function name

* update comments of "code generate"
2020-09-21 18:27:35 +08:00
kingxt
5e6dcac734 feature: goctl jwt (#91)
* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

* goctl add jwt support

Co-authored-by: kingxt <dream4kingxt@163.com>
2020-09-21 16:38:23 +08:00
dylanNew
3e7e466526 fix redis error (#88)
Co-authored-by: dylan <wangdi@xiaoheiban.cn>
2020-09-21 16:37:40 +08:00
kingxt
b6b8941a18 update doc (#90)
* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* update jwt doc

* update jwt doc

* update jwt doc

* update jwt doc

Co-authored-by: kingxt <dream4kingxt@163.com>
2020-09-21 16:09:02 +08:00
kingxt
878fd14739 remove no need (#87)
* rebase upstream

* rebase

* trim no need line

* trim no need line

* trim no need line

* update doc

* remove update

* remove no need

* remove no need

* add jwt doc

Co-authored-by: kingxt <dream4kingxt@163.com>
2020-09-21 14:29:12 +08:00
kevin
5e99f2b85d add trace/span in http logs 2020-09-20 22:02:45 +08:00
Howie
9c23399c33 chore: fix typos (#85)
* chore: fix typos

Signed-off-by: lihaowei <haoweili35@gmail.com>

* chore: fix 2 typos
2020-09-20 14:00:31 +08:00
kevin
86d3de4c89 use package level defined contextKey as context key 2020-09-20 12:46:35 +08:00
kevin
dc17855367 printing context key friendly 2020-09-20 12:08:30 +08:00
kevin
1606a92c6e use contextType as string type 2020-09-20 12:04:49 +08:00
mlboy
029fd3ea35 fix: golint: context.WithValue should should not use basic type as key (#83)
* fix: golint: context.WithValue should should not use basic type as key

* optimiz
2020-09-20 12:01:43 +08:00
kevin
57299a7597 rename ngin to rest in goctl 2020-09-20 09:15:19 +08:00
Changkun Ou
762af9dda2 optimize AtomicError (#82)
This commit optimize AtomicError using atomic.Value. Benchmarks:

name               old time/op  new time/op  delta
AtomicError/Load-6   305ns ±11%    12ns ± 6%  -96.18%  (p=0.000 n=10+10)
AtomicError/Set-6   314ns ±16%    14ns ± 2%  -95.61%  (p=0.000 n=10+9)
2020-09-18 22:45:01 +08:00
kevin
eccfaba614 update doc 2020-09-18 22:33:40 +08:00
kevin
974c19d6d3 update rpc example 2020-09-18 18:15:39 +08:00
Zhang Hao
0f8140031a fix rpc client examle (#81) 2020-09-18 18:07:08 +08:00
kevin
0b1ee79d3a rename rpcx to zrpc 2020-09-18 11:41:52 +08:00
Zhang Hao
26e16107ce fix example tracing edge config (#76) 2020-09-18 08:53:06 +08:00
kevin
1e5e9d63bd update wechat qrcode 2020-09-17 10:28:33 +08:00
254 changed files with 4192 additions and 847 deletions

View File

@@ -1,4 +1,4 @@
ignore: ignore:
- "doc" - "doc"
- "example" - "example"
- "tools" - "tools"

View File

@@ -1,6 +0,0 @@
{
"MD010": false,
"MD013": false,
"MD033": false,
"MD034": false
}

View File

@@ -13,15 +13,13 @@ const (
// maps as k in the error rate table // maps as k in the error rate table
maps = 14 maps = 14
setScript = ` setScript = `
local key = KEYS[1]
for _, offset in ipairs(ARGV) do for _, offset in ipairs(ARGV) do
redis.call("setbit", key, offset, 1) redis.call("setbit", KEYS[1], offset, 1)
end end
` `
testScript = ` testScript = `
local key = KEYS[1]
for _, offset in ipairs(ARGV) do for _, offset in ipairs(ARGV) do
if tonumber(redis.call("getbit", key, offset)) == 0 then if tonumber(redis.call("getbit", KEYS[1], offset)) == 0 then
return false return false
end end
end end

View File

@@ -13,11 +13,6 @@ import (
"github.com/tal-tech/go-zero/core/timex" "github.com/tal-tech/go-zero/core/timex"
) )
const (
StateClosed State = iota
StateOpen
)
const ( const (
numHistoryReasons = 5 numHistoryReasons = 5
timeFormat = "15:04:05" timeFormat = "15:04:05"
@@ -27,7 +22,6 @@ const (
var ErrServiceUnavailable = errors.New("circuit breaker is open") var ErrServiceUnavailable = errors.New("circuit breaker is open")
type ( type (
State = int32
Acceptable func(err error) bool Acceptable func(err error) bool
Breaker interface { Breaker interface {

View File

@@ -41,10 +41,13 @@ func GetBreaker(name string) Breaker {
} }
lock.Lock() lock.Lock()
defer lock.Unlock() b, ok = breakers[name]
if !ok {
b = NewBreaker(WithName(name))
breakers[name] = b
}
lock.Unlock()
b = NewBreaker()
breakers[name] = b
return b return b
} }
@@ -55,20 +58,5 @@ func NoBreakFor(name string) {
} }
func do(name string, execute func(b Breaker) error) error { func do(name string, execute func(b Breaker) error) error {
lock.RLock() return execute(GetBreaker(name))
b, ok := breakers[name]
lock.RUnlock()
if ok {
return execute(b)
}
lock.Lock()
b, ok = breakers[name]
if !ok {
b = NewBreaker(WithName(name))
breakers[name] = b
}
lock.Unlock()
return execute(b)
} }

View File

@@ -2,7 +2,6 @@ package breaker
import ( import (
"math" "math"
"sync/atomic"
"time" "time"
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
@@ -21,7 +20,6 @@ const (
// see Client-Side Throttling section in https://landing.google.com/sre/sre-book/chapters/handling-overload/ // see Client-Side Throttling section in https://landing.google.com/sre/sre-book/chapters/handling-overload/
type googleBreaker struct { type googleBreaker struct {
k float64 k float64
state int32
stat *collection.RollingWindow stat *collection.RollingWindow
proba *mathx.Proba proba *mathx.Proba
} }
@@ -32,7 +30,6 @@ func newGoogleBreaker() *googleBreaker {
return &googleBreaker{ return &googleBreaker{
stat: st, stat: st,
k: k, k: k,
state: StateClosed,
proba: mathx.NewProba(), proba: mathx.NewProba(),
} }
} }
@@ -43,15 +40,9 @@ func (b *googleBreaker) accept() error {
// https://landing.google.com/sre/sre-book/chapters/handling-overload/#eq2101 // https://landing.google.com/sre/sre-book/chapters/handling-overload/#eq2101
dropRatio := math.Max(0, (float64(total-protection)-weightedAccepts)/float64(total+1)) dropRatio := math.Max(0, (float64(total-protection)-weightedAccepts)/float64(total+1))
if dropRatio <= 0 { if dropRatio <= 0 {
if atomic.LoadInt32(&b.state) == StateOpen {
atomic.CompareAndSwapInt32(&b.state, StateOpen, StateClosed)
}
return nil return nil
} }
if atomic.LoadInt32(&b.state) == StateClosed {
atomic.CompareAndSwapInt32(&b.state, StateClosed, StateOpen)
}
if b.proba.TrueOnProba(dropRatio) { if b.proba.TrueOnProba(dropRatio) {
return ErrServiceUnavailable return ErrServiceUnavailable
} }

View File

@@ -27,7 +27,6 @@ func getGoogleBreaker() *googleBreaker {
return &googleBreaker{ return &googleBreaker{
stat: st, stat: st,
k: 5, k: 5,
state: StateClosed,
proba: mathx.NewProba(), proba: mathx.NewProba(),
} }
} }

View File

@@ -0,0 +1,80 @@
package cmdline
import (
"fmt"
"os"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/iox"
"github.com/tal-tech/go-zero/core/lang"
)
func TestEnterToContinue(t *testing.T) {
restore, err := iox.RedirectInOut()
assert.Nil(t, err)
defer restore()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
fmt.Println()
}()
go func() {
defer wg.Done()
EnterToContinue()
}()
wait := make(chan lang.PlaceholderType)
go func() {
wg.Wait()
close(wait)
}()
select {
case <-time.After(time.Second):
t.Error("timeout")
case <-wait:
}
}
func TestReadLine(t *testing.T) {
r, w, err := os.Pipe()
assert.Nil(t, err)
ow := os.Stdout
os.Stdout = w
or := os.Stdin
os.Stdin = r
defer func() {
os.Stdin = or
os.Stdout = ow
}()
const message = "hello"
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
fmt.Println(message)
}()
go func() {
defer wg.Done()
input := ReadLine("")
assert.Equal(t, message, input)
}()
wait := make(chan lang.PlaceholderType)
go func() {
wg.Wait()
close(wait)
}()
select {
case <-time.After(time.Second):
t.Error("timeout")
case <-wait:
}
}

View File

@@ -71,3 +71,12 @@ func TestDiffieHellmanMiddleManAttack(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, string(src), string(decryptedSrc)) assert.Equal(t, string(src), string(decryptedSrc))
} }
func TestKeyBytes(t *testing.T) {
var empty DhKey
assert.Equal(t, 0, len(empty.Bytes()))
key, err := GenerateKey()
assert.Nil(t, err)
assert.True(t, len(key.Bytes()) > 0)
}

19
core/codec/hmac_test.go Normal file
View File

@@ -0,0 +1,19 @@
package codec
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestHmac(t *testing.T) {
ret := Hmac([]byte("foo"), "bar")
assert.Equal(t, "f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317",
fmt.Sprintf("%x", ret))
}
func TestHmacBase64(t *testing.T) {
ret := HmacBase64([]byte("foo"), "bar")
assert.Equal(t, "+TILrwJJFp5zhQzWFW3tAQbiu2rYyrAbe7vr5tEGUxc=", ret)
}

59
core/codec/rsa_test.go Normal file
View File

@@ -0,0 +1,59 @@
package codec
import (
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/fs"
)
const (
priKey = `-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQC4TJk3onpqb2RYE3wwt23J9SHLFstHGSkUYFLe+nl1dEKHbD+/
Zt95L757J3xGTrwoTc7KCTxbrgn+stn0w52BNjj/kIE2ko4lbh/v8Fl14AyVR9ms
fKtKOnhe5FCT72mdtApr+qvzcC3q9hfXwkyQU32pv7q5UimZ205iKSBmgQIDAQAB
AoGAM5mWqGIAXj5z3MkP01/4CDxuyrrGDVD5FHBno3CDgyQa4Gmpa4B0/ywj671B
aTnwKmSmiiCN2qleuQYASixes2zY5fgTzt+7KNkl9JHsy7i606eH2eCKzsUa/s6u
WD8V3w/hGCQ9zYI18ihwyXlGHIgcRz/eeRh+nWcWVJzGOPUCQQD5nr6It/1yHb1p
C6l4fC4xXF19l4KxJjGu1xv/sOpSx0pOqBDEX3Mh//FU954392rUWDXV1/I65BPt
TLphdsu3AkEAvQJ2Qay/lffFj9FaUrvXuftJZ/Ypn0FpaSiUh3Ak3obBT6UvSZS0
bcYdCJCNHDtBOsWHnIN1x+BcWAPrdU7PhwJBAIQ0dUlH2S3VXnoCOTGc44I1Hzbj
Rc65IdsuBqA3fQN2lX5vOOIog3vgaFrOArg1jBkG1wx5IMvb/EnUN2pjVqUCQCza
KLXtCInOAlPemlCHwumfeAvznmzsWNdbieOZ+SXVVIpR6KbNYwOpv7oIk3Pfm9sW
hNffWlPUKhW42Gc+DIECQQDmk20YgBXwXWRM5DRPbhisIV088N5Z58K9DtFWkZsd
OBDT3dFcgZONtlmR1MqZO0pTh30lA4qovYj3Bx7A8i36
-----END RSA PRIVATE KEY-----`
pubKey = `-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC4TJk3onpqb2RYE3wwt23J9SHL
FstHGSkUYFLe+nl1dEKHbD+/Zt95L757J3xGTrwoTc7KCTxbrgn+stn0w52BNjj/
kIE2ko4lbh/v8Fl14AyVR9msfKtKOnhe5FCT72mdtApr+qvzcC3q9hfXwkyQU32p
v7q5UimZ205iKSBmgQIDAQAB
-----END PUBLIC KEY-----`
testBody = `this is the content`
encryptedBody = `49e7bc15640e5d927fd3f129b749536d0755baf03a0f35fc914ff1b7b8ce659e5fe3a598442eb908c5995e28bacd3d76e4420bb05b6bfc177040f66c6976f680f7123505d626ab96a9db1151f45c93bc0262db9087b9fb6801715f76f902e644a20029262858f05b0d10540842204346ac1d6d8f29cc5d47dab79af75d922ef2`
)
func TestCryption(t *testing.T) {
enc, err := NewRsaEncrypter([]byte(pubKey))
assert.Nil(t, err)
ret, err := enc.Encrypt([]byte(testBody))
assert.Nil(t, err)
file, err := fs.TempFilenameWithText(priKey)
assert.Nil(t, err)
dec, err := NewRsaDecrypter(file)
assert.Nil(t, err)
actual, err := dec.Decrypt(ret)
assert.Nil(t, err)
assert.Equal(t, testBody, string(actual))
actual, err = dec.DecryptBase64(base64.StdEncoding.EncodeToString(ret))
assert.Nil(t, err)
assert.Equal(t, testBody, string(actual))
}
func TestBadPubKey(t *testing.T) {
_, err := NewRsaEncrypter([]byte("foo"))
assert.Equal(t, ErrPublicKey, err)
}

View File

@@ -82,12 +82,7 @@ func (c *Cache) Del(key string) {
} }
func (c *Cache) Get(key string) (interface{}, bool) { func (c *Cache) Get(key string) (interface{}, bool) {
c.lock.Lock() value, ok := c.doGet(key)
value, ok := c.data[key]
if ok {
c.lruCache.add(key)
}
c.lock.Unlock()
if ok { if ok {
c.stats.IncrementHit() c.stats.IncrementHit()
} else { } else {
@@ -113,12 +108,25 @@ func (c *Cache) Set(key string, value interface{}) {
} }
func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}, error) { func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}, error) {
val, fresh, err := c.barrier.DoEx(key, func() (interface{}, error) { if val, ok := c.doGet(key); ok {
c.stats.IncrementHit()
return val, nil
}
var fresh bool
val, err := c.barrier.Do(key, func() (interface{}, error) {
// because O(1) on map search in memory, and fetch is an IO query
// so we do double check, cache might be taken by another call
if val, ok := c.doGet(key); ok {
return val, nil
}
v, e := fetch() v, e := fetch()
if e != nil { if e != nil {
return nil, e return nil, e
} }
fresh = true
c.Set(key, v) c.Set(key, v)
return v, nil return v, nil
}) })
@@ -137,6 +145,18 @@ func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}
return val, nil return val, nil
} }
func (c *Cache) doGet(key string) (interface{}, bool) {
c.lock.Lock()
defer c.lock.Unlock()
value, ok := c.data[key]
if ok {
c.lruCache.add(key)
}
return value, ok
}
func (c *Cache) onEvict(key string) { func (c *Cache) onEvict(key string) {
// already locked // already locked
delete(c.data, key) delete(c.data, key)

View File

@@ -1,6 +1,7 @@
package collection package collection
import ( import (
"errors"
"strconv" "strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -10,6 +11,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var errDummy = errors.New("dummy")
func TestCacheSet(t *testing.T) { func TestCacheSet(t *testing.T) {
cache, err := NewCache(time.Second*2, WithName("any")) cache, err := NewCache(time.Second*2, WithName("any"))
assert.Nil(t, err) assert.Nil(t, err)
@@ -63,6 +66,54 @@ func TestCacheTake(t *testing.T) {
assert.Equal(t, int32(1), atomic.LoadInt32(&count)) assert.Equal(t, int32(1), atomic.LoadInt32(&count))
} }
func TestCacheTakeExists(t *testing.T) {
cache, err := NewCache(time.Second * 2)
assert.Nil(t, err)
var count int32
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
cache.Set("first", "first element")
cache.Take("first", func() (interface{}, error) {
atomic.AddInt32(&count, 1)
time.Sleep(time.Millisecond * 100)
return "first element", nil
})
wg.Done()
}()
}
wg.Wait()
assert.Equal(t, 1, cache.size())
assert.Equal(t, int32(0), atomic.LoadInt32(&count))
}
func TestCacheTakeError(t *testing.T) {
cache, err := NewCache(time.Second * 2)
assert.Nil(t, err)
var count int32
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
_, err := cache.Take("first", func() (interface{}, error) {
atomic.AddInt32(&count, 1)
time.Sleep(time.Millisecond * 100)
return "", errDummy
})
assert.Equal(t, errDummy, err)
wg.Done()
}()
}
wg.Wait()
assert.Equal(t, 0, cache.size())
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
}
func TestCacheWithLruEvicts(t *testing.T) { func TestCacheWithLruEvicts(t *testing.T) {
cache, err := NewCache(time.Minute, WithLimit(3)) cache, err := NewCache(time.Minute, WithLimit(3))
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -15,6 +15,7 @@ const (
stringType stringType
) )
// Set is not thread-safe, for concurrent use, make sure to use it with synchronization.
type Set struct { type Set struct {
data map[interface{}]lang.PlaceholderType data map[interface{}]lang.PlaceholderType
tp int tp int
@@ -182,10 +183,7 @@ func (s *Set) add(i interface{}) {
} }
func (s *Set) setType(i interface{}) { func (s *Set) setType(i interface{}) {
if s.tp != untyped { // s.tp can only be untyped here
return
}
switch i.(type) { switch i.(type) {
case int: case int:
s.tp = intType s.tp = intType

View File

@@ -5,8 +5,13 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/logx"
) )
func init() {
logx.Disable()
}
func BenchmarkRawSet(b *testing.B) { func BenchmarkRawSet(b *testing.B) {
m := make(map[interface{}]struct{}) m := make(map[interface{}]struct{})
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@@ -147,3 +152,51 @@ func TestCount(t *testing.T) {
// then // then
assert.Equal(t, set.Count(), 3) assert.Equal(t, set.Count(), 3)
} }
func TestKeysIntMismatch(t *testing.T) {
set := NewSet()
set.add(int64(1))
set.add(2)
vals := set.KeysInt()
assert.EqualValues(t, []int{2}, vals)
}
func TestKeysInt64Mismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add(int64(2))
vals := set.KeysInt64()
assert.EqualValues(t, []int64{2}, vals)
}
func TestKeysUintMismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add(uint(2))
vals := set.KeysUint()
assert.EqualValues(t, []uint{2}, vals)
}
func TestKeysUint64Mismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add(uint64(2))
vals := set.KeysUint64()
assert.EqualValues(t, []uint64{2}, vals)
}
func TestKeysStrMismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add("2")
vals := set.KeysStr()
assert.EqualValues(t, []string{"2"}, vals)
}
func TestSetType(t *testing.T) {
set := NewUnmanagedSet()
set.add(1)
set.add("2")
vals := set.Keys()
assert.ElementsMatch(t, []interface{}{1, "2"}, vals)
}

58
core/conf/config_test.go Normal file
View File

@@ -0,0 +1,58 @@
package conf
import (
"io/ioutil"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/hash"
)
func TestConfigJson(t *testing.T) {
tests := []string{
".json",
".yaml",
".yml",
}
text := `{
"a": "foo",
"b": 1
}`
for _, test := range tests {
test := test
t.Run(test, func(t *testing.T) {
t.Parallel()
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
B int `json:"b"`
}
MustLoad(tmpfile, &val)
assert.Equal(t, "foo", val.A)
assert.Equal(t, 1, val.B)
})
}
}
func createTempFile(ext, text string) (string, error) {
tmpfile, err := ioutil.TempFile(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
if err != nil {
return "", err
}
if err := ioutil.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
return "", err
}
filename := tmpfile.Name()
if err = tmpfile.Close(); err != nil {
return "", err
}
return filename, nil
}

View File

@@ -30,12 +30,12 @@ type mapBasedProperties struct {
lock sync.RWMutex lock sync.RWMutex
} }
// Loads the properties into a properties configuration instance. May return the // Loads the properties into a properties configuration instance.
// configuration itself along with an error that indicates if there was a problem loading the configuration. // Returns an error that indicates if there was a problem loading the configuration.
func LoadProperties(filename string) (Properties, error) { func LoadProperties(filename string) (Properties, error) {
lines, err := iox.ReadTextLines(filename, iox.WithoutBlank(), iox.OmitWithPrefix("#")) lines, err := iox.ReadTextLines(filename, iox.WithoutBlank(), iox.OmitWithPrefix("#"))
if err != nil { if err != nil {
return nil, nil return nil, err
} }
raw := make(map[string]string) raw := make(map[string]string)

View File

@@ -41,3 +41,8 @@ func TestSetInt(t *testing.T) {
props.SetInt(key, value) props.SetInt(key, value)
assert.Equal(t, value, props.GetInt(key)) assert.Equal(t, value, props.GetInt(key))
} }
func TestLoadBadFile(t *testing.T) {
_, err := LoadProperties("nosuchfile")
assert.NotNil(t, err)
}

View File

@@ -1,4 +1,3 @@
//go:generate mockgen -package internal -destination listener_mock.go -source listener.go Listener
package internal package internal
type Listener interface { type Listener interface {

View File

@@ -1,45 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: listener.go
// Package internal is a generated GoMock package.
package internal
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockListener is a mock of Listener interface
type MockListener struct {
ctrl *gomock.Controller
recorder *MockListenerMockRecorder
}
// MockListenerMockRecorder is the mock recorder for MockListener
type MockListenerMockRecorder struct {
mock *MockListener
}
// NewMockListener creates a new mock instance
func NewMockListener(ctrl *gomock.Controller) *MockListener {
mock := &MockListener{ctrl: ctrl}
mock.recorder = &MockListenerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockListener) EXPECT() *MockListenerMockRecorder {
return m.recorder
}
// OnUpdate mocks base method
func (m *MockListener) OnUpdate(keys, values []string, newKey string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnUpdate", keys, values, newKey)
}
// OnUpdate indicates an expected call of OnUpdate
func (mr *MockListenerMockRecorder) OnUpdate(keys, values, newKey interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnUpdate", reflect.TypeOf((*MockListener)(nil).OnUpdate), keys, values, newKey)
}

View File

@@ -40,6 +40,7 @@ spec:
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
- --initial-cluster-state - --initial-cluster-state
- new - new
- --auto-compaction-retention=1
image: quay.io/coreos/etcd:latest image: quay.io/coreos/etcd:latest
name: etcd0 name: etcd0
ports: ports:
@@ -111,6 +112,7 @@ spec:
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
- --initial-cluster-state - --initial-cluster-state
- new - new
- --auto-compaction-retention=1
image: quay.io/coreos/etcd:latest image: quay.io/coreos/etcd:latest
name: etcd1 name: etcd1
ports: ports:
@@ -182,6 +184,7 @@ spec:
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
- --initial-cluster-state - --initial-cluster-state
- new - new
- --auto-compaction-retention=1
image: quay.io/coreos/etcd:latest image: quay.io/coreos/etcd:latest
name: etcd2 name: etcd2
ports: ports:
@@ -253,6 +256,7 @@ spec:
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
- --initial-cluster-state - --initial-cluster-state
- new - new
- --auto-compaction-retention=1
image: quay.io/coreos/etcd:latest image: quay.io/coreos/etcd:latest
name: etcd3 name: etcd3
ports: ports:
@@ -324,6 +328,7 @@ spec:
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380 - etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
- --initial-cluster-state - --initial-cluster-state
- new - new
- --auto-compaction-retention=1
image: quay.io/coreos/etcd:latest image: quay.io/coreos/etcd:latest
name: etcd4 name: etcd4
ports: ports:

View File

@@ -111,6 +111,10 @@ func TestPublisher_keepAliveAsyncQuit(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
const id clientv3.LeaseID = 1 const id clientv3.LeaseID = 1
cli := internal.NewMockEtcdClient(ctrl) cli := internal.NewMockEtcdClient(ctrl)
cli.EXPECT().ActiveConnection()
cli.EXPECT().Close()
defer cli.Close()
cli.ActiveConnection()
restore := setMockClient(cli) restore := setMockClient(cli)
defer restore() defer restore()
cli.EXPECT().Ctx().AnyTimes() cli.EXPECT().Ctx().AnyTimes()

View File

@@ -1,21 +1,18 @@
package errorx package errorx
import "sync" import "sync/atomic"
type AtomicError struct { type AtomicError struct {
err error err atomic.Value // error
lock sync.Mutex
} }
func (ae *AtomicError) Set(err error) { func (ae *AtomicError) Set(err error) {
ae.lock.Lock() ae.err.Store(err)
ae.err = err
ae.lock.Unlock()
} }
func (ae *AtomicError) Load() error { func (ae *AtomicError) Load() error {
ae.lock.Lock() if v := ae.err.Load(); v != nil {
err := ae.err return v.(error)
ae.lock.Unlock() }
return err return nil
} }

View File

@@ -2,6 +2,8 @@ package errorx
import ( import (
"errors" "errors"
"sync"
"sync/atomic"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -19,3 +21,53 @@ func TestAtomicErrorNil(t *testing.T) {
var err AtomicError var err AtomicError
assert.Nil(t, err.Load()) assert.Nil(t, err.Load())
} }
func BenchmarkAtomicError(b *testing.B) {
var aerr AtomicError
wg := sync.WaitGroup{}
b.Run("Load", func(b *testing.B) {
var done uint32
go func() {
for {
if atomic.LoadUint32(&done) != 0 {
break
}
wg.Add(1)
go func() {
aerr.Set(errDummy)
wg.Done()
}()
}
}()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = aerr.Load()
}
b.StopTimer()
atomic.StoreUint32(&done, 1)
wg.Wait()
})
b.Run("Set", func(b *testing.B) {
var done uint32
go func() {
for {
if atomic.LoadUint32(&done) != 0 {
break
}
wg.Add(1)
go func() {
_ = aerr.Load()
wg.Done()
}()
}
}()
b.ResetTimer()
for i := 0; i < b.N; i++ {
aerr.Set(errDummy)
}
b.StopTimer()
atomic.StoreUint32(&done, 1)
wg.Wait()
})
}

View File

@@ -84,6 +84,14 @@ func (p Stream) Buffer(n int) Stream {
return Range(source) return Range(source)
} }
// Count counts the number of elements in the result.
func (p Stream) Count() (count int) {
for range p.source {
count++
}
return
}
// Distinct removes the duplicated items base on the given KeyFunc. // Distinct removes the duplicated items base on the given KeyFunc.
func (p Stream) Distinct(fn KeyFunc) Stream { func (p Stream) Distinct(fn KeyFunc) Stream {
source := make(chan interface{}) source := make(chan interface{})

View File

@@ -49,6 +49,36 @@ func TestBufferNegative(t *testing.T) {
assert.Equal(t, 10, result) assert.Equal(t, 10, result)
} }
func TestCount(t *testing.T) {
tests := []struct {
name string
elements []interface{}
}{
{
name: "no elements with nil",
},
{
name: "no elements",
elements: []interface{}{},
},
{
name: "1 element",
elements: []interface{}{1},
},
{
name: "multiple elements",
elements: []interface{}{1, 2, 3},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
val := Just(test.elements...).Count()
assert.Equal(t, len(test.elements), val)
})
}
}
func TestDone(t *testing.T) { func TestDone(t *testing.T) {
var count int32 var count int32
Just(1, 2, 3).Walk(func(item interface{}, pipe chan<- interface{}) { Just(1, 2, 3).Walk(func(item interface{}, pipe chan<- interface{}) {

23
core/iox/pipe.go Normal file
View File

@@ -0,0 +1,23 @@
package iox
import "os"
// RedirectInOut redirects stdin to r, stdout to w, and callers need to call restore afterwards.
func RedirectInOut() (restore func(), err error) {
var r, w *os.File
r, w, err = os.Pipe()
if err != nil {
return
}
ow := os.Stdout
os.Stdout = w
or := os.Stdin
os.Stdin = r
restore = func() {
os.Stdin = or
os.Stdout = ow
}
return
}

13
core/iox/pipe_test.go Normal file
View File

@@ -0,0 +1,13 @@
package iox
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestRedirectInOut(t *testing.T) {
restore, err := RedirectInOut()
assert.Nil(t, err)
defer restore()
}

View File

@@ -135,6 +135,7 @@ func TestAdaptiveShedderShouldDrop(t *testing.T) {
passCounter: passCounter, passCounter: passCounter,
rtCounter: rtCounter, rtCounter: rtCounter,
windows: buckets, windows: buckets,
dropTime: syncx.NewAtomicDuration(),
droppedRecently: syncx.NewAtomicBool(), droppedRecently: syncx.NewAtomicBool(),
} }
// cpu >= 800, inflight < maxPass // cpu >= 800, inflight < maxPass
@@ -160,6 +161,40 @@ func TestAdaptiveShedderShouldDrop(t *testing.T) {
} }
shedder.avgFlying = 80 shedder.avgFlying = 80
assert.False(t, shedder.shouldDrop()) assert.False(t, shedder.shouldDrop())
// cpu >= 800, inflight < maxPass
systemOverloadChecker = func(int64) bool {
return true
}
shedder.avgFlying = 80
shedder.flying = 80
_, err := shedder.Allow()
assert.NotNil(t, err)
}
func TestAdaptiveShedderStillHot(t *testing.T) {
logx.Disable()
passCounter := newRollingWindow()
rtCounter := newRollingWindow()
for i := 0; i < 10; i++ {
if i > 0 {
time.Sleep(bucketDuration)
}
passCounter.Add(float64((i + 1) * 100))
for j := i*10 + 1; j <= i*10+10; j++ {
rtCounter.Add(float64(j))
}
}
shedder := &adaptiveShedder{
passCounter: passCounter,
rtCounter: rtCounter,
windows: buckets,
dropTime: syncx.NewAtomicDuration(),
droppedRecently: syncx.ForAtomicBool(true),
}
assert.False(t, shedder.stillHot())
shedder.dropTime.Set(-coolOffDuration * 2)
assert.False(t, shedder.stillHot())
} }
func BenchmarkAdaptiveShedder_Allow(b *testing.B) { func BenchmarkAdaptiveShedder_Allow(b *testing.B) {

View File

@@ -13,3 +13,8 @@ func TestGroup(t *testing.T) {
assert.NotNil(t, limiter) assert.NotNil(t, limiter)
}) })
} }
func TestShedderClose(t *testing.T) {
var nop nopCloser
assert.Nil(t, nop.Close())
}

View File

@@ -8,55 +8,60 @@ import (
"github.com/tal-tech/go-zero/core/timex" "github.com/tal-tech/go-zero/core/timex"
) )
const customCallerDepth = 3 const durationCallerDepth = 3
type customLog logEntry type durationLogger logEntry
func WithDuration(d time.Duration) Logger { func WithDuration(d time.Duration) Logger {
return customLog{ return &durationLogger{
Duration: timex.ReprOfDuration(d), Duration: timex.ReprOfDuration(d),
} }
} }
func (l customLog) Error(v ...interface{}) { func (l *durationLogger) Error(v ...interface{}) {
if shouldLog(ErrorLevel) { if shouldLog(ErrorLevel) {
l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), customCallerDepth)) l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), durationCallerDepth))
} }
} }
func (l customLog) Errorf(format string, v ...interface{}) { func (l *durationLogger) Errorf(format string, v ...interface{}) {
if shouldLog(ErrorLevel) { if shouldLog(ErrorLevel) {
l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), customCallerDepth)) l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), durationCallerDepth))
} }
} }
func (l customLog) Info(v ...interface{}) { func (l *durationLogger) Info(v ...interface{}) {
if shouldLog(InfoLevel) { if shouldLog(InfoLevel) {
l.write(infoLog, levelInfo, fmt.Sprint(v...)) l.write(infoLog, levelInfo, fmt.Sprint(v...))
} }
} }
func (l customLog) Infof(format string, v ...interface{}) { func (l *durationLogger) Infof(format string, v ...interface{}) {
if shouldLog(InfoLevel) { if shouldLog(InfoLevel) {
l.write(infoLog, levelInfo, fmt.Sprintf(format, v...)) l.write(infoLog, levelInfo, fmt.Sprintf(format, v...))
} }
} }
func (l customLog) Slow(v ...interface{}) { func (l *durationLogger) Slow(v ...interface{}) {
if shouldLog(ErrorLevel) { if shouldLog(ErrorLevel) {
l.write(slowLog, levelSlow, fmt.Sprint(v...)) l.write(slowLog, levelSlow, fmt.Sprint(v...))
} }
} }
func (l customLog) Slowf(format string, v ...interface{}) { func (l *durationLogger) Slowf(format string, v ...interface{}) {
if shouldLog(ErrorLevel) { if shouldLog(ErrorLevel) {
l.write(slowLog, levelSlow, fmt.Sprintf(format, v...)) l.write(slowLog, levelSlow, fmt.Sprintf(format, v...))
} }
} }
func (l customLog) write(writer io.Writer, level, content string) { func (l *durationLogger) WithDuration(duration time.Duration) Logger {
l.Duration = timex.ReprOfDuration(duration)
return l
}
func (l *durationLogger) write(writer io.Writer, level, content string) {
l.Timestamp = getTimestamp() l.Timestamp = getTimestamp()
l.Level = level l.Level = level
l.Content = content l.Content = content
outputJson(writer, logEntry(l)) outputJson(writer, logEntry(*l))
} }

View File

@@ -0,0 +1,52 @@
package logx
import (
"log"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestWithDurationError(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
WithDuration(time.Second).Error("foo")
assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
}
func TestWithDurationErrorf(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
WithDuration(time.Second).Errorf("foo")
assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
}
func TestWithDurationInfo(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
WithDuration(time.Second).Info("foo")
assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
}
func TestWithDurationInfof(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
WithDuration(time.Second).Infof("foo")
assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
}
func TestWithDurationSlow(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
WithDuration(time.Second).Slow("foo")
assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
}
func TestWithDurationSlowf(t *testing.T) {
var builder strings.Builder
log.SetOutput(&builder)
WithDuration(time.Second).WithDuration(time.Hour).Slowf("foo")
assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
}

View File

@@ -15,6 +15,7 @@ import (
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/tal-tech/go-zero/core/iox" "github.com/tal-tech/go-zero/core/iox"
"github.com/tal-tech/go-zero/core/sysx" "github.com/tal-tech/go-zero/core/sysx"
@@ -96,6 +97,7 @@ type (
Infof(string, ...interface{}) Infof(string, ...interface{})
Slow(...interface{}) Slow(...interface{})
Slowf(string, ...interface{}) Slowf(string, ...interface{})
WithDuration(time.Duration) Logger
} }
) )

View File

@@ -6,8 +6,10 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"os"
"runtime" "runtime"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@@ -21,10 +23,13 @@ var (
) )
type mockWriter struct { type mockWriter struct {
lock sync.Mutex
builder strings.Builder builder strings.Builder
} }
func (mw *mockWriter) Write(data []byte) (int, error) { func (mw *mockWriter) Write(data []byte) (int, error) {
mw.lock.Lock()
defer mw.lock.Unlock()
return mw.builder.Write(data) return mw.builder.Write(data)
} }
@@ -32,12 +37,22 @@ func (mw *mockWriter) Close() error {
return nil return nil
} }
func (mw *mockWriter) Contains(text string) bool {
mw.lock.Lock()
defer mw.lock.Unlock()
return strings.Contains(mw.builder.String(), text)
}
func (mw *mockWriter) Reset() { func (mw *mockWriter) Reset() {
mw.lock.Lock()
defer mw.lock.Unlock()
mw.builder.Reset() mw.builder.Reset()
} }
func (mw *mockWriter) Contains(text string) bool { func (mw *mockWriter) String() string {
return strings.Contains(mw.builder.String(), text) mw.lock.Lock()
defer mw.lock.Unlock()
return mw.builder.String()
} }
func TestFileLineFileMode(t *testing.T) { func TestFileLineFileMode(t *testing.T) {
@@ -85,6 +100,46 @@ func TestStructedLogSlow(t *testing.T) {
}) })
} }
func TestStructedLogSlowf(t *testing.T) {
doTestStructedLog(t, levelSlow, func(writer io.WriteCloser) {
slowLog = writer
}, func(v ...interface{}) {
Slowf(fmt.Sprint(v...))
})
}
func TestStructedLogStat(t *testing.T) {
doTestStructedLog(t, levelStat, func(writer io.WriteCloser) {
statLog = writer
}, func(v ...interface{}) {
Stat(v...)
})
}
func TestStructedLogStatf(t *testing.T) {
doTestStructedLog(t, levelStat, func(writer io.WriteCloser) {
statLog = writer
}, func(v ...interface{}) {
Statf(fmt.Sprint(v...))
})
}
func TestStructedLogSevere(t *testing.T) {
doTestStructedLog(t, levelSevere, func(writer io.WriteCloser) {
severeLog = writer
}, func(v ...interface{}) {
Severe(v...)
})
}
func TestStructedLogSeveref(t *testing.T) {
doTestStructedLog(t, levelSevere, func(writer io.WriteCloser) {
severeLog = writer
}, func(v ...interface{}) {
Severef(fmt.Sprint(v...))
})
}
func TestStructedLogWithDuration(t *testing.T) { func TestStructedLogWithDuration(t *testing.T) {
const message = "hello there" const message = "hello there"
writer := new(mockWriter) writer := new(mockWriter)
@@ -135,6 +190,66 @@ func TestMustNil(t *testing.T) {
Must(nil) Must(nil)
} }
func TestSetup(t *testing.T) {
MustSetup(LogConf{
ServiceName: "any",
Mode: "console",
})
MustSetup(LogConf{
ServiceName: "any",
Mode: "file",
Path: os.TempDir(),
})
MustSetup(LogConf{
ServiceName: "any",
Mode: "volume",
Path: os.TempDir(),
})
assert.NotNil(t, setupWithVolume(LogConf{}))
assert.NotNil(t, setupWithFiles(LogConf{}))
assert.Nil(t, setupWithFiles(LogConf{
ServiceName: "any",
Path: os.TempDir(),
Compress: true,
KeepDays: 1,
}))
setupLogLevel(LogConf{
Level: levelInfo,
})
setupLogLevel(LogConf{
Level: levelError,
})
setupLogLevel(LogConf{
Level: levelSevere,
})
_, err := createOutput("")
assert.NotNil(t, err)
Disable()
}
func TestDisable(t *testing.T) {
Disable()
WithKeepDays(1)
WithGzip()
assert.Nil(t, Close())
writeConsole = false
assert.Nil(t, Close())
}
func TestWithGzip(t *testing.T) {
fn := WithGzip()
var opt logOptions
fn(&opt)
assert.True(t, opt.gzipEnabled)
}
func TestWithKeepDays(t *testing.T) {
fn := WithKeepDays(1)
var opt logOptions
fn(&opt)
assert.Equal(t, 1, opt.keepDays)
}
func BenchmarkCopyByteSliceAppend(b *testing.B) { func BenchmarkCopyByteSliceAppend(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var buf []byte var buf []byte
@@ -232,7 +347,7 @@ func doTestStructedLog(t *testing.T, level string, setup func(writer io.WriteClo
t.Error(err) t.Error(err)
} }
assert.Equal(t, level, entry.Level) assert.Equal(t, level, entry.Level)
assert.Equal(t, message, entry.Content) assert.True(t, strings.Contains(entry.Content, message))
} }
func testSetLevelTwiceWithMode(t *testing.T, mode string) { func testSetLevelTwiceWithMode(t *testing.T, mode string) {
@@ -252,4 +367,10 @@ func testSetLevelTwiceWithMode(t *testing.T, mode string) {
atomic.StoreUint32(&initialized, 1) atomic.StoreUint32(&initialized, 1)
Info(message) Info(message)
assert.Equal(t, 0, writer.builder.Len()) assert.Equal(t, 0, writer.builder.Len())
Infof(message)
assert.Equal(t, 0, writer.builder.Len())
ErrorStack(message)
assert.Equal(t, 0, writer.builder.Len())
ErrorStackf(message)
assert.Equal(t, 0, writer.builder.Len())
} }

View File

@@ -192,14 +192,16 @@ func (l *RotateLogger) init() error {
} }
func (l *RotateLogger) maybeCompressFile(file string) { func (l *RotateLogger) maybeCompressFile(file string) {
if l.compress { if !l.compress {
defer func() { return
if r := recover(); r != nil {
ErrorStack(r)
}
}()
compressLogFile(file)
} }
defer func() {
if r := recover(); r != nil {
ErrorStack(r)
}
}()
compressLogFile(file)
} }
func (l *RotateLogger) maybeDeleteOutdatedFiles() { func (l *RotateLogger) maybeDeleteOutdatedFiles() {

View File

@@ -0,0 +1,119 @@
package logx
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/fs"
)
func TestDailyRotateRuleMarkRotated(t *testing.T) {
var rule DailyRotateRule
rule.MarkRotated()
assert.Equal(t, getNowDate(), rule.rotatedTime)
}
func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
var rule DailyRotateRule
assert.Empty(t, rule.OutdatedFiles())
rule.days = 1
assert.Empty(t, rule.OutdatedFiles())
rule.gzip = true
assert.Empty(t, rule.OutdatedFiles())
}
func TestDailyRotateRuleShallRotate(t *testing.T) {
var rule DailyRotateRule
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(dateFormat)
assert.True(t, rule.ShallRotate())
}
func TestRotateLoggerClose(t *testing.T) {
filename, err := fs.TempFilenameWithText("foo")
assert.Nil(t, err)
if len(filename) > 0 {
defer os.Remove(filename)
}
logger, err := NewLogger(filename, new(DailyRotateRule), false)
assert.Nil(t, err)
assert.Nil(t, logger.Close())
}
func TestRotateLoggerGetBackupFilename(t *testing.T) {
filename, err := fs.TempFilenameWithText("foo")
assert.Nil(t, err)
if len(filename) > 0 {
defer os.Remove(filename)
}
logger, err := NewLogger(filename, new(DailyRotateRule), false)
assert.Nil(t, err)
assert.True(t, len(logger.getBackupFilename()) > 0)
logger.backup = ""
assert.True(t, len(logger.getBackupFilename()) > 0)
}
func TestRotateLoggerMayCompressFile(t *testing.T) {
filename, err := fs.TempFilenameWithText("foo")
assert.Nil(t, err)
if len(filename) > 0 {
defer os.Remove(filename)
}
logger, err := NewLogger(filename, new(DailyRotateRule), false)
assert.Nil(t, err)
logger.maybeCompressFile(filename)
_, err = os.Stat(filename)
assert.Nil(t, err)
}
func TestRotateLoggerMayCompressFileTrue(t *testing.T) {
filename, err := fs.TempFilenameWithText("foo")
assert.Nil(t, err)
logger, err := NewLogger(filename, new(DailyRotateRule), true)
assert.Nil(t, err)
if len(filename) > 0 {
defer func() {
os.Remove(filename)
os.Remove(filepath.Base(logger.getBackupFilename()) + ".gz")
}()
}
logger.maybeCompressFile(filename)
_, err = os.Stat(filename)
assert.NotNil(t, err)
}
func TestRotateLoggerRotate(t *testing.T) {
filename, err := fs.TempFilenameWithText("foo")
assert.Nil(t, err)
logger, err := NewLogger(filename, new(DailyRotateRule), true)
assert.Nil(t, err)
if len(filename) > 0 {
defer func() {
os.Remove(filename)
os.Remove(logger.getBackupFilename())
os.Remove(filepath.Base(logger.getBackupFilename()) + ".gz")
}()
}
err = logger.rotate()
assert.Nil(t, err)
}
func TestRotateLoggerWrite(t *testing.T) {
filename, err := fs.TempFilenameWithText("foo")
assert.Nil(t, err)
rule := new(DailyRotateRule)
logger, err := NewLogger(filename, rule, true)
assert.Nil(t, err)
if len(filename) > 0 {
defer func() {
os.Remove(filename)
os.Remove(logger.getBackupFilename())
os.Remove(filepath.Base(logger.getBackupFilename()) + ".gz")
}()
}
logger.write([]byte(`foo`))
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
logger.write([]byte(`bar`))
}

View File

@@ -33,10 +33,10 @@ func captureOutput(f func()) string {
writer := new(mockWriter) writer := new(mockWriter)
infoLog = writer infoLog = writer
prevLevel := logLevel prevLevel := atomic.LoadUint32(&logLevel)
logLevel = InfoLevel SetLevel(InfoLevel)
f() f()
logLevel = prevLevel SetLevel(prevLevel)
return writer.builder.String() return writer.builder.String()
} }

View File

@@ -1,49 +0,0 @@
package logx
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/trace/tracespec"
)
const (
mockTraceId = "mock-trace-id"
mockSpanId = "mock-span-id"
)
var mock tracespec.Trace = new(mockTrace)
func TestTraceLog(t *testing.T) {
var buf strings.Builder
ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
WithContext(ctx).(tracingEntry).write(&buf, levelInfo, testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId))
assert.True(t, strings.Contains(buf.String(), mockSpanId))
}
type mockTrace struct{}
func (t mockTrace) TraceId() string {
return mockTraceId
}
func (t mockTrace) SpanId() string {
return mockSpanId
}
func (t mockTrace) Finish() {
}
func (t mockTrace) Fork(ctx context.Context, serviceName, operationName string) (context.Context, tracespec.Trace) {
return nil, nil
}
func (t mockTrace) Follow(ctx context.Context, serviceName, operationName string) (context.Context, tracespec.Trace) {
return nil, nil
}
func (t mockTrace) Visit(fn func(key string, val string) bool) {
}

View File

@@ -4,54 +4,61 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"time"
"github.com/tal-tech/go-zero/core/timex"
"github.com/tal-tech/go-zero/core/trace/tracespec" "github.com/tal-tech/go-zero/core/trace/tracespec"
) )
type tracingEntry struct { type traceLogger struct {
logEntry logEntry
Trace string `json:"trace,omitempty"` Trace string `json:"trace,omitempty"`
Span string `json:"span,omitempty"` Span string `json:"span,omitempty"`
ctx context.Context `json:"-"` ctx context.Context
} }
func (l tracingEntry) Error(v ...interface{}) { func (l *traceLogger) Error(v ...interface{}) {
if shouldLog(ErrorLevel) { if shouldLog(ErrorLevel) {
l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), customCallerDepth)) l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), durationCallerDepth))
} }
} }
func (l tracingEntry) Errorf(format string, v ...interface{}) { func (l *traceLogger) Errorf(format string, v ...interface{}) {
if shouldLog(ErrorLevel) { if shouldLog(ErrorLevel) {
l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), customCallerDepth)) l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), durationCallerDepth))
} }
} }
func (l tracingEntry) Info(v ...interface{}) { func (l *traceLogger) Info(v ...interface{}) {
if shouldLog(InfoLevel) { if shouldLog(InfoLevel) {
l.write(infoLog, levelInfo, fmt.Sprint(v...)) l.write(infoLog, levelInfo, fmt.Sprint(v...))
} }
} }
func (l tracingEntry) Infof(format string, v ...interface{}) { func (l *traceLogger) Infof(format string, v ...interface{}) {
if shouldLog(InfoLevel) { if shouldLog(InfoLevel) {
l.write(infoLog, levelInfo, fmt.Sprintf(format, v...)) l.write(infoLog, levelInfo, fmt.Sprintf(format, v...))
} }
} }
func (l tracingEntry) Slow(v ...interface{}) { func (l *traceLogger) Slow(v ...interface{}) {
if shouldLog(ErrorLevel) { if shouldLog(ErrorLevel) {
l.write(slowLog, levelSlow, fmt.Sprint(v...)) l.write(slowLog, levelSlow, fmt.Sprint(v...))
} }
} }
func (l tracingEntry) Slowf(format string, v ...interface{}) { func (l *traceLogger) Slowf(format string, v ...interface{}) {
if shouldLog(ErrorLevel) { if shouldLog(ErrorLevel) {
l.write(slowLog, levelSlow, fmt.Sprintf(format, v...)) l.write(slowLog, levelSlow, fmt.Sprintf(format, v...))
} }
} }
func (l tracingEntry) write(writer io.Writer, level, content string) { func (l *traceLogger) WithDuration(duration time.Duration) Logger {
l.Duration = timex.ReprOfDuration(duration)
return l
}
func (l *traceLogger) write(writer io.Writer, level, content string) {
l.Timestamp = getTimestamp() l.Timestamp = getTimestamp()
l.Level = level l.Level = level
l.Content = content l.Content = content
@@ -61,7 +68,7 @@ func (l tracingEntry) write(writer io.Writer, level, content string) {
} }
func WithContext(ctx context.Context) Logger { func WithContext(ctx context.Context) Logger {
return tracingEntry{ return &traceLogger{
ctx: ctx, ctx: ctx,
} }
} }

View File

@@ -0,0 +1,115 @@
package logx
import (
"context"
"log"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/trace/tracespec"
)
const (
mockTraceId = "mock-trace-id"
mockSpanId = "mock-span-id"
)
var mock tracespec.Trace = new(mockTrace)
func TestTraceLog(t *testing.T) {
var buf mockWriter
ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId))
assert.True(t, strings.Contains(buf.String(), mockSpanId))
}
func TestTraceError(t *testing.T) {
var buf mockWriter
ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
l := WithContext(ctx).(*traceLogger)
SetLevel(InfoLevel)
atomic.StoreUint32(&initialized, 1)
errorLog = newLogWriter(log.New(&buf, "", flags))
l.WithDuration(time.Second).Error(testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId))
assert.True(t, strings.Contains(buf.String(), mockSpanId))
buf.Reset()
l.WithDuration(time.Second).Errorf(testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId))
assert.True(t, strings.Contains(buf.String(), mockSpanId))
}
func TestTraceInfo(t *testing.T) {
var buf mockWriter
ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
l := WithContext(ctx).(*traceLogger)
SetLevel(InfoLevel)
atomic.StoreUint32(&initialized, 1)
infoLog = newLogWriter(log.New(&buf, "", flags))
l.WithDuration(time.Second).Info(testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId))
assert.True(t, strings.Contains(buf.String(), mockSpanId))
buf.Reset()
l.WithDuration(time.Second).Infof(testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId))
assert.True(t, strings.Contains(buf.String(), mockSpanId))
}
func TestTraceSlow(t *testing.T) {
var buf mockWriter
ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
l := WithContext(ctx).(*traceLogger)
SetLevel(InfoLevel)
atomic.StoreUint32(&initialized, 1)
slowLog = newLogWriter(log.New(&buf, "", flags))
l.WithDuration(time.Second).Slow(testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId))
assert.True(t, strings.Contains(buf.String(), mockSpanId))
buf.Reset()
l.WithDuration(time.Second).Slowf(testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId))
assert.True(t, strings.Contains(buf.String(), mockSpanId))
}
func TestTraceWithoutContext(t *testing.T) {
var buf mockWriter
l := WithContext(context.Background()).(*traceLogger)
SetLevel(InfoLevel)
atomic.StoreUint32(&initialized, 1)
infoLog = newLogWriter(log.New(&buf, "", flags))
l.WithDuration(time.Second).Info(testlog)
assert.False(t, strings.Contains(buf.String(), mockTraceId))
assert.False(t, strings.Contains(buf.String(), mockSpanId))
buf.Reset()
l.WithDuration(time.Second).Infof(testlog)
assert.False(t, strings.Contains(buf.String(), mockTraceId))
assert.False(t, strings.Contains(buf.String(), mockSpanId))
}
type mockTrace struct{}
func (t mockTrace) TraceId() string {
return mockTraceId
}
func (t mockTrace) SpanId() string {
return mockSpanId
}
func (t mockTrace) Finish() {
}
func (t mockTrace) Fork(ctx context.Context, serviceName, operationName string) (context.Context, tracespec.Trace) {
return nil, nil
}
func (t mockTrace) Follow(ctx context.Context, serviceName, operationName string) (context.Context, tracespec.Trace) {
return nil, nil
}
func (t mockTrace) Visit(fn func(key string, val string) bool) {
}

View File

@@ -0,0 +1,31 @@
package mapping
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
type Bar struct {
Val string `json:"val"`
}
func TestFieldOptionOptionalDep(t *testing.T) {
var bar Bar
rt := reflect.TypeOf(bar)
for i := 0; i < rt.NumField(); i++ {
field := rt.Field(i)
val, opt, err := parseKeyAndOptions(jsonTagKey, field)
assert.Equal(t, "val", val)
assert.Nil(t, opt)
assert.Nil(t, err)
}
// check nil working
var o *fieldOptions
check := func(o *fieldOptions) {
assert.Equal(t, 0, len(o.optionalDep()))
}
check(o)
}

View File

@@ -23,6 +23,7 @@ const (
var ( var (
errTypeMismatch = errors.New("type mismatch") errTypeMismatch = errors.New("type mismatch")
errValueNotSettable = errors.New("value is not settable") errValueNotSettable = errors.New("value is not settable")
errValueNotStruct = errors.New("value type is not struct")
keyUnmarshaler = NewUnmarshaler(defaultKeyName) keyUnmarshaler = NewUnmarshaler(defaultKeyName)
cacheKeys atomic.Value cacheKeys atomic.Value
cacheKeysLock sync.Mutex cacheKeysLock sync.Mutex
@@ -80,6 +81,10 @@ func (u *Unmarshaler) unmarshalWithFullName(m Valuer, v interface{}, fullName st
} }
rte := reflect.TypeOf(v).Elem() rte := reflect.TypeOf(v).Elem()
if rte.Kind() != reflect.Struct {
return errValueNotStruct
}
rve := rv.Elem() rve := rv.Elem()
numFields := rte.NumField() numFields := rte.NumField()
for i := 0; i < numFields; i++ { for i := 0; i < numFields; i++ {
@@ -345,7 +350,7 @@ func (u *Unmarshaler) processNamedFieldWithValue(field reflect.StructField, valu
options := opts.options() options := opts.options()
if len(options) > 0 { if len(options) > 0 {
if !stringx.Contains(options, mapValue.(string)) { if !stringx.Contains(options, mapValue.(string)) {
return fmt.Errorf(`error: value "%s" for field "%s" is not defined in opts "%v"`, return fmt.Errorf(`error: value "%s" for field "%s" is not defined in options "%v"`,
mapValue, key, options) mapValue, key, options)
} }
} }

View File

@@ -14,6 +14,13 @@ import (
// so we only can test to 62 bits. // so we only can test to 62 bits.
const maxUintBitsToTest = 62 const maxUintBitsToTest = 62
func TestUnmarshalWithFullNameNotStruct(t *testing.T) {
var s map[string]interface{}
content := []byte(`{"name":"xiaoming"}`)
err := UnmarshalJsonBytes(content, &s)
assert.Equal(t, errValueNotStruct, err)
}
func TestUnmarshalWithoutTagName(t *testing.T) { func TestUnmarshalWithoutTagName(t *testing.T) {
type inner struct { type inner struct {
Optional bool `key:",optional"` Optional bool `key:",optional"`
@@ -2380,6 +2387,13 @@ func TestUnmarshalNestedMapSimpleTypeMatch(t *testing.T) {
assert.Equal(t, "1", c.Anything["id"]) assert.Equal(t, "1", c.Anything["id"])
} }
func TestUnmarshalValuer(t *testing.T) {
unmarshaler := NewUnmarshaler(jsonTagKey)
var foo string
err := unmarshaler.UnmarshalValuer(nil, foo)
assert.NotNil(t, err)
}
func BenchmarkUnmarshalString(b *testing.B) { func BenchmarkUnmarshalString(b *testing.B) {
type inner struct { type inner struct {
Value string `key:"value"` Value string `key:"value"`

View File

@@ -0,0 +1,16 @@
package proc
import (
"log"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestDumpGoroutines(t *testing.T) {
var buf strings.Builder
log.SetOutput(&buf)
dumpGoroutines()
assert.True(t, strings.Contains(buf.String(), ".dump"))
}

21
core/proc/profile_test.go Normal file
View File

@@ -0,0 +1,21 @@
package proc
import (
"log"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestProfile(t *testing.T) {
var buf strings.Builder
log.SetOutput(&buf)
profiler := StartProfile()
// start again should not work
assert.NotNil(t, StartProfile())
profiler.Stop()
// stop twice
profiler.Stop()
assert.True(t, strings.Contains(buf.String(), ".pprof"))
}

22
core/stat/alert_test.go Normal file
View File

@@ -0,0 +1,22 @@
// +build linux
package stat
import (
"strconv"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
)
func TestReport(t *testing.T) {
var count int32
SetReporter(func(s string) {
atomic.AddInt32(&count, 1)
})
for i := 0; i < 10; i++ {
Report(strconv.Itoa(i))
}
assert.Equal(t, int32(1), count)
}

View File

@@ -1,6 +1,14 @@
package internal package internal
import "testing" import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestRefreshCpu(t *testing.T) {
assert.True(t, RefreshCpu() >= 0)
}
func BenchmarkRefreshCpu(b *testing.B) { func BenchmarkRefreshCpu(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

37
core/stat/metrics_test.go Normal file
View File

@@ -0,0 +1,37 @@
package stat
import (
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestMetrics(t *testing.T) {
counts := []int{1, 5, 10, 100, 1000, 1000}
for _, count := range counts {
m := NewMetrics("foo")
m.SetName("bar")
for i := 0; i < count; i++ {
m.Add(Task{
Duration: time.Millisecond * time.Duration(i),
Description: strconv.Itoa(i),
})
}
m.AddDrop()
var writer mockedWriter
SetReportWriter(&writer)
m.executor.Flush()
assert.Equal(t, "bar", writer.report.Name)
}
}
type mockedWriter struct {
report *StatReport
}
func (m *mockedWriter) Write(report *StatReport) error {
m.report = report
return nil
}

View File

@@ -0,0 +1,30 @@
package stat
import (
"testing"
"github.com/stretchr/testify/assert"
"gopkg.in/h2non/gock.v1"
)
func TestRemoteWriter(t *testing.T) {
defer gock.Off()
gock.New("http://foo.com").Reply(200).BodyString("foo")
writer := NewRemoteWriter("http://foo.com")
err := writer.Write(&StatReport{
Name: "bar",
})
assert.Nil(t, err)
}
func TestRemoteWriterFail(t *testing.T) {
defer gock.Off()
gock.New("http://foo.com").Reply(503).BodyString("foo")
writer := NewRemoteWriter("http://foo.com")
err := writer.Write(&StatReport{
Name: "bar",
})
assert.NotNil(t, err)
}

View File

@@ -1,4 +1,4 @@
package internal package cache
import ( import (
"fmt" "fmt"

View File

@@ -1,4 +1,4 @@
package internal package cache
import ( import (
"encoding/json" "encoding/json"
@@ -111,6 +111,45 @@ func TestCache_SetDel(t *testing.T) {
assert.Nil(t, c.GetCache(fmt.Sprintf("key/%d", i), &v)) assert.Nil(t, c.GetCache(fmt.Sprintf("key/%d", i), &v))
assert.Equal(t, i, v) assert.Equal(t, i, v)
} }
assert.Nil(t, c.DelCache())
for i := 0; i < total; i++ {
assert.Nil(t, c.DelCache(fmt.Sprintf("key/%d", i)))
}
for i := 0; i < total; i++ {
var v int
assert.Equal(t, errPlaceholder, c.GetCache(fmt.Sprintf("key/%d", i), &v))
assert.Equal(t, 0, v)
}
}
func TestCache_OneNode(t *testing.T) {
const total = 1000
r := miniredis.NewMiniRedis()
assert.Nil(t, r.Start())
defer r.Close()
conf := ClusterConf{
{
RedisConf: redis.RedisConf{
Host: r.Addr(),
Type: redis.NodeType,
},
Weight: 100,
},
}
c := NewCache(conf, syncx.NewSharedCalls(), NewCacheStat("mock"), errPlaceholder)
for i := 0; i < total; i++ {
if i%2 == 0 {
assert.Nil(t, c.SetCache(fmt.Sprintf("key/%d", i), i))
} else {
assert.Nil(t, c.SetCacheWithExpire(fmt.Sprintf("key/%d", i), i, 0))
}
}
for i := 0; i < total; i++ {
var v int
assert.Nil(t, c.GetCache(fmt.Sprintf("key/%d", i), &v))
assert.Equal(t, i, v)
}
assert.Nil(t, c.DelCache())
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
assert.Nil(t, c.DelCache(fmt.Sprintf("key/%d", i))) assert.Nil(t, c.DelCache(fmt.Sprintf("key/%d", i)))
} }
@@ -188,6 +227,25 @@ func TestCache_Balance(t *testing.T) {
assert.Equal(t, total/10, count) assert.Equal(t, total/10, count)
} }
func TestCacheNoNode(t *testing.T) {
dispatcher := hash.NewConsistentHash()
c := cacheCluster{
dispatcher: dispatcher,
errNotFound: errPlaceholder,
}
assert.NotNil(t, c.DelCache("foo"))
assert.NotNil(t, c.DelCache("foo", "bar", "any"))
assert.NotNil(t, c.GetCache("foo", nil))
assert.NotNil(t, c.SetCache("foo", nil))
assert.NotNil(t, c.SetCacheWithExpire("foo", nil, time.Second))
assert.NotNil(t, c.Take(nil, "foo", func(v interface{}) error {
return nil
}))
assert.NotNil(t, c.TakeWithExpire(nil, "foo", func(v interface{}, duration time.Duration) error {
return nil
}))
}
func calcEntropy(m map[int]int, total int) float64 { func calcEntropy(m map[int]int, total int) float64 {
var entropy float64 var entropy float64

View File

@@ -1,5 +1,3 @@
package cache package cache
import "github.com/tal-tech/go-zero/core/stores/internal" type CacheConf = ClusterConf
type CacheConf = internal.ClusterConf

View File

@@ -1,13 +1,13 @@
package internal package cache
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"math/rand" "math/rand"
"sync" "sync"
"time" "time"
"github.com/tal-tech/go-zero/core/jsonx"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/mathx" "github.com/tal-tech/go-zero/core/mathx"
"github.com/tal-tech/go-zero/core/stat" "github.com/tal-tech/go-zero/core/stat"
@@ -79,7 +79,7 @@ func (c cacheNode) SetCache(key string, v interface{}) error {
} }
func (c cacheNode) SetCacheWithExpire(key string, v interface{}, expire time.Duration) error { func (c cacheNode) SetCacheWithExpire(key string, v interface{}, expire time.Duration) error {
data, err := json.Marshal(v) data, err := jsonx.Marshal(v)
if err != nil { if err != nil {
return err return err
} }
@@ -168,7 +168,7 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e
} }
} }
return json.Marshal(v) return jsonx.Marshal(v)
}) })
if err != nil { if err != nil {
return err return err
@@ -181,11 +181,11 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e
c.stat.IncrementHit() c.stat.IncrementHit()
} }
return json.Unmarshal(val.([]byte), v) return jsonx.Unmarshal(val.([]byte), v)
} }
func (c cacheNode) processCache(key string, data string, v interface{}) error { func (c cacheNode) processCache(key string, data string, v interface{}) error {
err := json.Unmarshal([]byte(data), v) err := jsonx.Unmarshal([]byte(data), v)
if err == nil { if err == nil {
return nil return nil
} }

208
core/stores/cache/cachenode_test.go vendored Normal file
View File

@@ -0,0 +1,208 @@
package cache
import (
"errors"
"fmt"
"math/rand"
"strconv"
"sync"
"testing"
"time"
"github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/mathx"
"github.com/tal-tech/go-zero/core/stat"
"github.com/tal-tech/go-zero/core/stores/redis"
"github.com/tal-tech/go-zero/core/syncx"
)
var errTestNotFound = errors.New("not found")
func init() {
logx.Disable()
stat.SetReporter(nil)
}
func TestCacheNode_DelCache(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
defer s.Close()
cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType),
r: rand.New(rand.NewSource(time.Now().UnixNano())),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewCacheStat("any"),
errNotFound: errTestNotFound,
}
assert.Nil(t, cn.DelCache())
assert.Nil(t, cn.DelCache([]string{}...))
assert.Nil(t, cn.DelCache(make([]string, 0)...))
cn.SetCache("first", "one")
assert.Nil(t, cn.DelCache("first"))
cn.SetCache("first", "one")
cn.SetCache("second", "two")
assert.Nil(t, cn.DelCache("first", "second"))
}
func TestCacheNode_InvalidCache(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
defer s.Close()
cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType),
r: rand.New(rand.NewSource(time.Now().UnixNano())),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewCacheStat("any"),
errNotFound: errTestNotFound,
}
s.Set("any", "value")
var str string
assert.NotNil(t, cn.GetCache("any", &str))
assert.Equal(t, "", str)
_, err = s.Get("any")
assert.Equal(t, miniredis.ErrKeyNotFound, err)
}
func TestCacheNode_Take(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
defer s.Close()
cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType),
r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSharedCalls(),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewCacheStat("any"),
errNotFound: errTestNotFound,
}
var str string
err = cn.Take(&str, "any", func(v interface{}) error {
*v.(*string) = "value"
return nil
})
assert.Nil(t, err)
assert.Equal(t, "value", str)
assert.Nil(t, cn.GetCache("any", &str))
val, err := s.Get("any")
assert.Nil(t, err)
assert.Equal(t, `"value"`, val)
}
func TestCacheNode_TakeNotFound(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
defer s.Close()
cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType),
r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSharedCalls(),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewCacheStat("any"),
errNotFound: errTestNotFound,
}
var str string
err = cn.Take(&str, "any", func(v interface{}) error {
return errTestNotFound
})
assert.Equal(t, errTestNotFound, err)
assert.Equal(t, errTestNotFound, cn.GetCache("any", &str))
val, err := s.Get("any")
assert.Nil(t, err)
assert.Equal(t, `*`, val)
s.Set("any", "*")
err = cn.Take(&str, "any", func(v interface{}) error {
return nil
})
assert.Equal(t, errTestNotFound, err)
assert.Equal(t, errTestNotFound, cn.GetCache("any", &str))
s.Del("any")
var errDummy = errors.New("dummy")
err = cn.Take(&str, "any", func(v interface{}) error {
return errDummy
})
assert.Equal(t, errDummy, err)
}
func TestCacheNode_TakeWithExpire(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
defer s.Close()
cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType),
r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSharedCalls(),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewCacheStat("any"),
errNotFound: errors.New("any"),
}
var str string
err = cn.TakeWithExpire(&str, "any", func(v interface{}, expire time.Duration) error {
*v.(*string) = "value"
return nil
})
assert.Nil(t, err)
assert.Equal(t, "value", str)
assert.Nil(t, cn.GetCache("any", &str))
val, err := s.Get("any")
assert.Nil(t, err)
assert.Equal(t, `"value"`, val)
}
func TestCacheNode_String(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
defer s.Close()
cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType),
r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSharedCalls(),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewCacheStat("any"),
errNotFound: errors.New("any"),
}
assert.Equal(t, s.Addr(), cn.String())
}
func TestCacheValueWithBigInt(t *testing.T) {
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
defer s.Close()
cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType),
r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSharedCalls(),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewCacheStat("any"),
errNotFound: errors.New("any"),
}
const (
key = "key"
value int64 = 323427211229009810
)
assert.Nil(t, cn.SetCache(key, value))
var val interface{}
assert.Nil(t, cn.GetCache(key, &val))
assert.Equal(t, strconv.FormatInt(value, 10), fmt.Sprintf("%v", val))
}

View File

@@ -1,21 +1,45 @@
package cache package cache
import ( import "time"
"time"
"github.com/tal-tech/go-zero/core/stores/internal" const (
defaultExpiry = time.Hour * 24 * 7
defaultNotFoundExpiry = time.Minute
) )
type Option = internal.Option type (
Options struct {
Expiry time.Duration
NotFoundExpiry time.Duration
}
Option func(o *Options)
)
func newOptions(opts ...Option) Options {
var o Options
for _, opt := range opts {
opt(&o)
}
if o.Expiry <= 0 {
o.Expiry = defaultExpiry
}
if o.NotFoundExpiry <= 0 {
o.NotFoundExpiry = defaultNotFoundExpiry
}
return o
}
func WithExpiry(expiry time.Duration) Option { func WithExpiry(expiry time.Duration) Option {
return func(o *internal.Options) { return func(o *Options) {
o.Expiry = expiry o.Expiry = expiry
} }
} }
func WithNotFoundExpiry(expiry time.Duration) Option { func WithNotFoundExpiry(expiry time.Duration) Option {
return func(o *internal.Options) { return func(o *Options) {
o.NotFoundExpiry = expiry o.NotFoundExpiry = expiry
} }
} }

View File

@@ -1,4 +1,4 @@
package internal package cache
import ( import (
"sync/atomic" "sync/atomic"

View File

@@ -1,4 +1,4 @@
package internal package cache
import ( import (
"fmt" "fmt"

56
core/stores/cache/cleaner_test.go vendored Normal file
View File

@@ -0,0 +1,56 @@
package cache
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNextDelay(t *testing.T) {
tests := []struct {
name string
input time.Duration
output time.Duration
ok bool
}{
{
name: "second",
input: time.Second,
output: time.Second * 5,
ok: true,
},
{
name: "5 seconds",
input: time.Second * 5,
output: time.Minute,
ok: true,
},
{
name: "minute",
input: time.Minute,
output: time.Minute * 5,
ok: true,
},
{
name: "5 minutes",
input: time.Minute * 5,
output: time.Hour,
ok: true,
},
{
name: "hour",
input: time.Hour,
output: 0,
ok: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
next, ok := nextDelay(test.input)
assert.Equal(t, test.ok, ok)
assert.Equal(t, test.output, next)
})
}
}

View File

@@ -1,4 +1,4 @@
package internal package cache
import "github.com/tal-tech/go-zero/core/stores/redis" import "github.com/tal-tech/go-zero/core/stores/redis"

View File

@@ -1,4 +1,4 @@
package internal package cache
import "strings" import "strings"

26
core/stores/cache/util_test.go vendored Normal file
View File

@@ -0,0 +1,26 @@
package cache
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFormatKeys(t *testing.T) {
assert.Equal(t, "a,b", formatKeys([]string{"a", "b"}))
}
func TestTotalWeights(t *testing.T) {
val := TotalWeights([]NodeConf{
{
Weight: -1,
},
{
Weight: 0,
},
{
Weight: 1,
},
})
assert.Equal(t, 1, val)
}

View File

@@ -1,65 +0,0 @@
package internal
import (
"errors"
"math/rand"
"sync"
"testing"
"time"
"github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/mathx"
"github.com/tal-tech/go-zero/core/stat"
"github.com/tal-tech/go-zero/core/stores/redis"
)
func init() {
logx.Disable()
stat.SetReporter(nil)
}
func TestCacheNode_DelCache(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
defer s.Close()
cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType),
r: rand.New(rand.NewSource(time.Now().UnixNano())),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewCacheStat("any"),
errNotFound: errors.New("any"),
}
assert.Nil(t, cn.DelCache())
assert.Nil(t, cn.DelCache([]string{}...))
assert.Nil(t, cn.DelCache(make([]string, 0)...))
cn.SetCache("first", "one")
assert.Nil(t, cn.DelCache("first"))
cn.SetCache("first", "one")
cn.SetCache("second", "two")
assert.Nil(t, cn.DelCache("first", "second"))
}
func TestCacheNode_InvalidCache(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
defer s.Close()
cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType),
r: rand.New(rand.NewSource(time.Now().UnixNano())),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewCacheStat("any"),
errNotFound: errors.New("any"),
}
s.Set("any", "value")
var str string
assert.NotNil(t, cn.GetCache("any", &str))
assert.Equal(t, "", str)
_, err = s.Get("any")
assert.Equal(t, miniredis.ErrKeyNotFound, err)
}

View File

@@ -1,33 +0,0 @@
package internal
import "time"
const (
defaultExpiry = time.Hour * 24 * 7
defaultNotFoundExpiry = time.Minute
)
type (
Options struct {
Expiry time.Duration
NotFoundExpiry time.Duration
}
Option func(o *Options)
)
func newOptions(opts ...Option) Options {
var o Options
for _, opt := range opts {
opt(&o)
}
if o.Expiry <= 0 {
o.Expiry = defaultExpiry
}
if o.NotFoundExpiry <= 0 {
o.NotFoundExpiry = defaultNotFoundExpiry
}
return o
}

View File

@@ -1,5 +1,7 @@
package kv package kv
import "github.com/tal-tech/go-zero/core/stores/internal" import (
"github.com/tal-tech/go-zero/core/stores/cache"
)
type KvConf = internal.ClusterConf type KvConf = cache.ClusterConf

View File

@@ -6,7 +6,7 @@ import (
"github.com/tal-tech/go-zero/core/errorx" "github.com/tal-tech/go-zero/core/errorx"
"github.com/tal-tech/go-zero/core/hash" "github.com/tal-tech/go-zero/core/hash"
"github.com/tal-tech/go-zero/core/stores/internal" "github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
) )
@@ -81,7 +81,7 @@ type (
) )
func NewStore(c KvConf) Store { func NewStore(c KvConf) Store {
if len(c) == 0 || internal.TotalWeights(c) <= 0 { if len(c) == 0 || cache.TotalWeights(c) <= 0 {
log.Fatal("no cache nodes") log.Fatal("no cache nodes")
} }

View File

@@ -6,7 +6,8 @@ import (
"github.com/alicebob/miniredis" "github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stores/internal" "github.com/tal-tech/go-zero/core/hash"
"github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
"github.com/tal-tech/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
) )
@@ -15,6 +16,10 @@ var s1, _ = miniredis.Run()
var s2, _ = miniredis.Run() var s2, _ = miniredis.Run()
func TestRedis_Exists(t *testing.T) { func TestRedis_Exists(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Exists("foo")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
ok, err := client.Exists("a") ok, err := client.Exists("a")
assert.Nil(t, err) assert.Nil(t, err)
@@ -27,6 +32,10 @@ func TestRedis_Exists(t *testing.T) {
} }
func TestRedis_Eval(t *testing.T) { func TestRedis_Eval(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Eval(`redis.call("EXISTS", KEYS[1])`, "key1")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
_, err := client.Eval(`redis.call("EXISTS", KEYS[1])`, "notexist") _, err := client.Eval(`redis.call("EXISTS", KEYS[1])`, "notexist")
assert.Equal(t, redis.Nil, err) assert.Equal(t, redis.Nil, err)
@@ -41,6 +50,12 @@ func TestRedis_Eval(t *testing.T) {
} }
func TestRedis_Hgetall(t *testing.T) { func TestRedis_Hgetall(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
err := store.Hset("a", "aa", "aaa")
assert.NotNil(t, err)
_, err = store.Hgetall("a")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb")) assert.Nil(t, client.Hset("a", "bb", "bbb"))
@@ -54,6 +69,10 @@ func TestRedis_Hgetall(t *testing.T) {
} }
func TestRedis_Hvals(t *testing.T) { func TestRedis_Hvals(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Hvals("a")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb")) assert.Nil(t, client.Hset("a", "bb", "bbb"))
@@ -64,6 +83,10 @@ func TestRedis_Hvals(t *testing.T) {
} }
func TestRedis_Hsetnx(t *testing.T) { func TestRedis_Hsetnx(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Hsetnx("a", "dd", "ddd")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb")) assert.Nil(t, client.Hset("a", "bb", "bbb"))
@@ -80,6 +103,12 @@ func TestRedis_Hsetnx(t *testing.T) {
} }
func TestRedis_HdelHlen(t *testing.T) { func TestRedis_HdelHlen(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Hdel("a", "aa")
assert.NotNil(t, err)
_, err = store.Hlen("a")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb")) assert.Nil(t, client.Hset("a", "bb", "bbb"))
@@ -96,6 +125,10 @@ func TestRedis_HdelHlen(t *testing.T) {
} }
func TestRedis_HIncrBy(t *testing.T) { func TestRedis_HIncrBy(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Hincrby("key", "field", 3)
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
val, err := client.Hincrby("key", "field", 2) val, err := client.Hincrby("key", "field", 2)
assert.Nil(t, err) assert.Nil(t, err)
@@ -107,6 +140,10 @@ func TestRedis_HIncrBy(t *testing.T) {
} }
func TestRedis_Hkeys(t *testing.T) { func TestRedis_Hkeys(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Hkeys("a")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb")) assert.Nil(t, client.Hset("a", "bb", "bbb"))
@@ -117,6 +154,10 @@ func TestRedis_Hkeys(t *testing.T) {
} }
func TestRedis_Hmget(t *testing.T) { func TestRedis_Hmget(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Hmget("a", "aa", "bb")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb")) assert.Nil(t, client.Hset("a", "bb", "bbb"))
@@ -130,6 +171,12 @@ func TestRedis_Hmget(t *testing.T) {
} }
func TestRedis_Hmset(t *testing.T) { func TestRedis_Hmset(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
err := store.Hmset("a", map[string]string{
"aa": "aaa",
})
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hmset("a", map[string]string{ assert.Nil(t, client.Hmset("a", map[string]string{
"aa": "aaa", "aa": "aaa",
@@ -142,6 +189,10 @@ func TestRedis_Hmset(t *testing.T) {
} }
func TestRedis_Incr(t *testing.T) { func TestRedis_Incr(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Incr("a")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
val, err := client.Incr("a") val, err := client.Incr("a")
assert.Nil(t, err) assert.Nil(t, err)
@@ -153,6 +204,10 @@ func TestRedis_Incr(t *testing.T) {
} }
func TestRedis_IncrBy(t *testing.T) { func TestRedis_IncrBy(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Incrby("a", 2)
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
val, err := client.Incrby("a", 2) val, err := client.Incrby("a", 2)
assert.Nil(t, err) assert.Nil(t, err)
@@ -164,6 +219,20 @@ func TestRedis_IncrBy(t *testing.T) {
} }
func TestRedis_List(t *testing.T) { func TestRedis_List(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Lpush("key", "value1", "value2")
assert.NotNil(t, err)
_, err = store.Rpush("key", "value3", "value4")
assert.NotNil(t, err)
_, err = store.Llen("key")
assert.NotNil(t, err)
_, err = store.Lrange("key", 0, 10)
assert.NotNil(t, err)
_, err = store.Lpop("key")
assert.NotNil(t, err)
_, err = store.Lrem("key", 0, "val")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
val, err := client.Lpush("key", "value1", "value2") val, err := client.Lpush("key", "value1", "value2")
assert.Nil(t, err) assert.Nil(t, err)
@@ -202,6 +271,14 @@ func TestRedis_List(t *testing.T) {
} }
func TestRedis_Persist(t *testing.T) { func TestRedis_Persist(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Persist("key")
assert.NotNil(t, err)
err = store.Expire("key", 5)
assert.NotNil(t, err)
err = store.Expireat("key", time.Now().Unix()+5)
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
ok, err := client.Persist("key") ok, err := client.Persist("key")
assert.Nil(t, err) assert.Nil(t, err)
@@ -225,8 +302,16 @@ func TestRedis_Persist(t *testing.T) {
} }
func TestRedis_Sscan(t *testing.T) { func TestRedis_Sscan(t *testing.T) {
key := "list"
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Sadd(key, nil)
assert.NotNil(t, err)
_, _, err = store.Sscan(key, 0, "", 100)
assert.NotNil(t, err)
_, err = store.Del(key)
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
key := "list"
var list []string var list []string
for i := 0; i < 1550; i++ { for i := 0; i < 1550; i++ {
list = append(list, stringx.Randn(i)) list = append(list, stringx.Randn(i))
@@ -254,6 +339,20 @@ func TestRedis_Sscan(t *testing.T) {
} }
func TestRedis_Set(t *testing.T) { func TestRedis_Set(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Scard("key")
assert.NotNil(t, err)
_, err = store.Sismember("key", 2)
assert.NotNil(t, err)
_, err = store.Srem("key", 3, 4)
assert.NotNil(t, err)
_, err = store.Smembers("key")
assert.NotNil(t, err)
_, err = store.Srandmember("key", 1)
assert.NotNil(t, err)
_, err = store.Spop("key")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
num, err := client.Sadd("key", 1, 2, 3, 4) num, err := client.Sadd("key", 1, 2, 3, 4)
assert.Nil(t, err) assert.Nil(t, err)
@@ -290,6 +389,14 @@ func TestRedis_Set(t *testing.T) {
} }
func TestRedis_SetGetDel(t *testing.T) { func TestRedis_SetGetDel(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
err := store.Set("hello", "world")
assert.NotNil(t, err)
_, err = store.Get("hello")
assert.NotNil(t, err)
_, err = store.Del("hello")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
err := client.Set("hello", "world") err := client.Set("hello", "world")
assert.Nil(t, err) assert.Nil(t, err)
@@ -303,6 +410,16 @@ func TestRedis_SetGetDel(t *testing.T) {
} }
func TestRedis_SetExNx(t *testing.T) { func TestRedis_SetExNx(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
err := store.Setex("hello", "world", 5)
assert.NotNil(t, err)
_, err = store.Setnx("newhello", "newworld")
assert.NotNil(t, err)
_, err = store.Ttl("hello")
assert.NotNil(t, err)
_, err = store.SetnxEx("newhello", "newworld", 5)
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
err := client.Setex("hello", "world", 5) err := client.Setex("hello", "world", 5)
assert.Nil(t, err) assert.Nil(t, err)
@@ -337,6 +454,16 @@ func TestRedis_SetExNx(t *testing.T) {
} }
func TestRedis_SetGetDelHashField(t *testing.T) { func TestRedis_SetGetDelHashField(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
err := store.Hset("key", "field", "value")
assert.NotNil(t, err)
_, err = store.Hget("key", "field")
assert.NotNil(t, err)
_, err = store.Hexists("key", "field")
assert.NotNil(t, err)
_, err = store.Hdel("key", "field")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
err := client.Hset("key", "field", "value") err := client.Hset("key", "field", "value")
assert.Nil(t, err) assert.Nil(t, err)
@@ -356,6 +483,48 @@ func TestRedis_SetGetDelHashField(t *testing.T) {
} }
func TestRedis_SortedSet(t *testing.T) { func TestRedis_SortedSet(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Zadd("key", 1, "value1")
assert.NotNil(t, err)
_, err = store.Zscore("key", "value1")
assert.NotNil(t, err)
_, err = store.Zcount("key", 6, 7)
assert.NotNil(t, err)
_, err = store.Zincrby("key", 3, "value1")
assert.NotNil(t, err)
_, err = store.Zrank("key", "value2")
assert.NotNil(t, err)
_, err = store.Zrem("key", "value2", "value3")
assert.NotNil(t, err)
_, err = store.Zremrangebyscore("key", 6, 7)
assert.NotNil(t, err)
_, err = store.Zremrangebyrank("key", 1, 2)
assert.NotNil(t, err)
_, err = store.Zcard("key")
assert.NotNil(t, err)
_, err = store.Zrange("key", 0, -1)
assert.NotNil(t, err)
_, err = store.Zrevrange("key", 0, -1)
assert.NotNil(t, err)
_, err = store.ZrangeWithScores("key", 0, -1)
assert.NotNil(t, err)
_, err = store.ZrangebyscoreWithScores("key", 5, 8)
assert.NotNil(t, err)
_, err = store.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
assert.NotNil(t, err)
_, err = store.ZrevrangebyscoreWithScores("key", 5, 8)
assert.NotNil(t, err)
_, err = store.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
assert.NotNil(t, err)
_, err = store.Zadds("key", redis.Pair{
Key: "value2",
Score: 6,
}, redis.Pair{
Key: "value3",
Score: 7,
})
assert.NotNil(t, err)
runOnCluster(t, func(client Store) { runOnCluster(t, func(client Store) {
ok, err := client.Zadd("key", 1, "value1") ok, err := client.Zadd("key", 1, "value1")
assert.Nil(t, err) assert.Nil(t, err)
@@ -471,6 +640,30 @@ func TestRedis_SortedSet(t *testing.T) {
Score: 5, Score: 5,
}, },
}, pairs) }, pairs)
val, err = client.Zadds("key", redis.Pair{
Key: "value2",
Score: 6,
}, redis.Pair{
Key: "value3",
Score: 7,
})
assert.Nil(t, err)
assert.Equal(t, int64(2), val)
})
}
func TestRedis_HyperLogLog(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.Pfadd("key")
assert.NotNil(t, err)
_, err = store.Pfcount("key")
assert.NotNil(t, err)
runOnCluster(t, func(cluster Store) {
_, err := cluster.Pfadd("key")
assert.NotNil(t, err)
_, err = cluster.Pfcount("key")
assert.NotNil(t, err)
}) })
} }
@@ -478,7 +671,7 @@ func runOnCluster(t *testing.T, fn func(cluster Store)) {
s1.FlushAll() s1.FlushAll()
s2.FlushAll() s2.FlushAll()
store := NewStore([]internal.NodeConf{ store := NewStore([]cache.NodeConf{
{ {
RedisConf: redis.RedisConf{ RedisConf: redis.RedisConf{
Host: s1.Addr(), Host: s1.Addr(),

View File

@@ -7,6 +7,7 @@ import (
"github.com/globalsign/mgo" "github.com/globalsign/mgo"
"github.com/tal-tech/go-zero/core/breaker" "github.com/tal-tech/go-zero/core/breaker"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stores/mongo/internal"
"github.com/tal-tech/go-zero/core/timex" "github.com/tal-tech/go-zero/core/timex"
) )
@@ -29,8 +30,9 @@ type (
} }
decoratedCollection struct { decoratedCollection struct {
*mgo.Collection name string
brk breaker.Breaker collection internal.MgoCollection
brk breaker.Breaker
} }
keepablePromise struct { keepablePromise struct {
@@ -41,7 +43,8 @@ type (
func newCollection(collection *mgo.Collection) Collection { func newCollection(collection *mgo.Collection) Collection {
return &decoratedCollection{ return &decoratedCollection{
Collection: collection, name: collection.FullName,
collection: collection,
brk: breaker.NewBreaker(), brk: breaker.NewBreaker(),
} }
} }
@@ -54,7 +57,7 @@ func (c *decoratedCollection) Find(query interface{}) Query {
startTime := timex.Now() startTime := timex.Now()
return promisedQuery{ return promisedQuery{
Query: c.Collection.Find(query), Query: c.collection.Find(query),
promise: keepablePromise{ promise: keepablePromise{
promise: promise, promise: promise,
log: func(err error) { log: func(err error) {
@@ -73,7 +76,7 @@ func (c *decoratedCollection) FindId(id interface{}) Query {
startTime := timex.Now() startTime := timex.Now()
return promisedQuery{ return promisedQuery{
Query: c.Collection.FindId(id), Query: c.collection.FindId(id),
promise: keepablePromise{ promise: keepablePromise{
promise: promise, promise: promise,
log: func(err error) { log: func(err error) {
@@ -92,7 +95,7 @@ func (c *decoratedCollection) Insert(docs ...interface{}) (err error) {
c.logDuration("insert", duration, err, docs...) c.logDuration("insert", duration, err, docs...)
}() }()
return c.Collection.Insert(docs...) return c.collection.Insert(docs...)
}, acceptable) }, acceptable)
} }
@@ -104,7 +107,7 @@ func (c *decoratedCollection) Pipe(pipeline interface{}) Pipe {
startTime := timex.Now() startTime := timex.Now()
return promisedPipe{ return promisedPipe{
Pipe: c.Collection.Pipe(pipeline), Pipe: c.collection.Pipe(pipeline),
promise: keepablePromise{ promise: keepablePromise{
promise: promise, promise: promise,
log: func(err error) { log: func(err error) {
@@ -123,7 +126,7 @@ func (c *decoratedCollection) Remove(selector interface{}) (err error) {
c.logDuration("remove", duration, err, selector) c.logDuration("remove", duration, err, selector)
}() }()
return c.Collection.Remove(selector) return c.collection.Remove(selector)
}, acceptable) }, acceptable)
} }
@@ -135,7 +138,7 @@ func (c *decoratedCollection) RemoveAll(selector interface{}) (info *mgo.ChangeI
c.logDuration("removeAll", duration, err, selector) c.logDuration("removeAll", duration, err, selector)
}() }()
info, err = c.Collection.RemoveAll(selector) info, err = c.collection.RemoveAll(selector)
return err return err
}, acceptable) }, acceptable)
@@ -150,7 +153,7 @@ func (c *decoratedCollection) RemoveId(id interface{}) (err error) {
c.logDuration("removeId", duration, err, id) c.logDuration("removeId", duration, err, id)
}() }()
return c.Collection.RemoveId(id) return c.collection.RemoveId(id)
}, acceptable) }, acceptable)
} }
@@ -162,7 +165,7 @@ func (c *decoratedCollection) Update(selector, update interface{}) (err error) {
c.logDuration("update", duration, err, selector, update) c.logDuration("update", duration, err, selector, update)
}() }()
return c.Collection.Update(selector, update) return c.collection.Update(selector, update)
}, acceptable) }, acceptable)
} }
@@ -174,7 +177,7 @@ func (c *decoratedCollection) UpdateId(id, update interface{}) (err error) {
c.logDuration("updateId", duration, err, id, update) c.logDuration("updateId", duration, err, id, update)
}() }()
return c.Collection.UpdateId(id, update) return c.collection.UpdateId(id, update)
}, acceptable) }, acceptable)
} }
@@ -186,7 +189,7 @@ func (c *decoratedCollection) Upsert(selector, update interface{}) (info *mgo.Ch
c.logDuration("upsert", duration, err, selector, update) c.logDuration("upsert", duration, err, selector, update)
}() }()
info, err = c.Collection.Upsert(selector, update) info, err = c.collection.Upsert(selector, update)
return err return err
}, acceptable) }, acceptable)
@@ -200,17 +203,17 @@ func (c *decoratedCollection) logDuration(method string, duration time.Duration,
} else if err != nil { } else if err != nil {
if duration > slowThreshold { if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - fail(%s) - %s", logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - fail(%s) - %s",
c.FullName, method, err.Error(), string(content)) c.name, method, err.Error(), string(content))
} else { } else {
logx.WithDuration(duration).Infof("mongo(%s) - %s - fail(%s) - %s", logx.WithDuration(duration).Infof("mongo(%s) - %s - fail(%s) - %s",
c.FullName, method, err.Error(), string(content)) c.name, method, err.Error(), string(content))
} }
} else { } else {
if duration > slowThreshold { if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - ok - %s", logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - ok - %s",
c.FullName, method, string(content)) c.name, method, string(content))
} else { } else {
logx.WithDuration(duration).Infof("mongo(%s) - %s - ok - %s", c.FullName, method, string(content)) logx.WithDuration(duration).Infof("mongo(%s) - %s - ok - %s", c.name, method, string(content))
} }
} }
} }

View File

@@ -5,10 +5,20 @@ import (
"testing" "testing"
"github.com/globalsign/mgo" "github.com/globalsign/mgo"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/breaker"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stores/mongo/internal"
"github.com/tal-tech/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
) )
var errDummy = errors.New("dummy")
func init() {
logx.Disable()
}
func TestKeepPromise_accept(t *testing.T) { func TestKeepPromise_accept(t *testing.T) {
p := new(mockPromise) p := new(mockPromise)
kp := keepablePromise{ kp := keepablePromise{
@@ -56,6 +66,206 @@ func TestKeepPromise_keep(t *testing.T) {
} }
} }
func TestNewCollection(t *testing.T) {
col := newCollection(&mgo.Collection{
Database: nil,
Name: "foo",
FullName: "bar",
})
assert.Equal(t, "bar", col.(*decoratedCollection).name)
}
func TestCollectionFind(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var query mgo.Query
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().Find(gomock.Any()).Return(&query)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
actual := c.Find(nil)
switch v := actual.(type) {
case promisedQuery:
assert.Equal(t, &query, v.Query)
assert.Equal(t, errDummy, v.promise.keep(errDummy))
default:
t.Fail()
}
c.brk = new(dropBreaker)
actual = c.Find(nil)
assert.Equal(t, rejectedQuery{}, actual)
}
func TestCollectionFindId(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var query mgo.Query
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().FindId(gomock.Any()).Return(&query)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
actual := c.FindId(nil)
switch v := actual.(type) {
case promisedQuery:
assert.Equal(t, &query, v.Query)
assert.Equal(t, errDummy, v.promise.keep(errDummy))
default:
t.Fail()
}
c.brk = new(dropBreaker)
actual = c.FindId(nil)
assert.Equal(t, rejectedQuery{}, actual)
}
func TestCollectionInsert(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().Insert(nil, nil).Return(errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
err := c.Insert(nil, nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
err = c.Insert(nil, nil)
assert.Equal(t, errDummy, err)
}
func TestCollectionPipe(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var pipe mgo.Pipe
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().Pipe(gomock.Any()).Return(&pipe)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
actual := c.Pipe(nil)
switch v := actual.(type) {
case promisedPipe:
assert.Equal(t, &pipe, v.Pipe)
assert.Equal(t, errDummy, v.promise.keep(errDummy))
default:
t.Fail()
}
c.brk = new(dropBreaker)
actual = c.Pipe(nil)
assert.Equal(t, rejectedPipe{}, actual)
}
func TestCollectionRemove(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().Remove(gomock.Any()).Return(errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
err := c.Remove(nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
err = c.Remove(nil)
assert.Equal(t, errDummy, err)
}
func TestCollectionRemoveAll(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().RemoveAll(gomock.Any()).Return(nil, errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
_, err := c.RemoveAll(nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
_, err = c.RemoveAll(nil)
assert.Equal(t, errDummy, err)
}
func TestCollectionRemoveId(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().RemoveId(gomock.Any()).Return(errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
err := c.RemoveId(nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
err = c.RemoveId(nil)
assert.Equal(t, errDummy, err)
}
func TestCollectionUpdate(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().Update(gomock.Any(), gomock.Any()).Return(errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
err := c.Update(nil, nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
err = c.Update(nil, nil)
assert.Equal(t, errDummy, err)
}
func TestCollectionUpdateId(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().UpdateId(gomock.Any(), gomock.Any()).Return(errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
err := c.UpdateId(nil, nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
err = c.UpdateId(nil, nil)
assert.Equal(t, errDummy, err)
}
func TestCollectionUpsert(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
col := internal.NewMockMgoCollection(ctrl)
col.EXPECT().Upsert(gomock.Any(), gomock.Any()).Return(nil, errDummy)
c := decoratedCollection{
collection: col,
brk: breaker.NewBreaker(),
}
_, err := c.Upsert(nil, nil)
assert.Equal(t, errDummy, err)
c.brk = new(dropBreaker)
_, err = c.Upsert(nil, nil)
assert.Equal(t, errDummy, err)
}
type mockPromise struct { type mockPromise struct {
accepted bool accepted bool
reason string reason string
@@ -68,3 +278,31 @@ func (p *mockPromise) Accept() {
func (p *mockPromise) Reject(reason string) { func (p *mockPromise) Reject(reason string) {
p.reason = reason p.reason = reason
} }
type dropBreaker struct {
}
func (d *dropBreaker) Name() string {
return "dummy"
}
func (d *dropBreaker) Allow() (breaker.Promise, error) {
return nil, errDummy
}
func (d *dropBreaker) Do(req func() error) error {
return nil
}
func (d *dropBreaker) DoWithAcceptable(req func() error, acceptable breaker.Acceptable) error {
return errDummy
}
func (d *dropBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
return nil
}
func (d *dropBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
acceptable breaker.Acceptable) error {
return nil
}

View File

@@ -0,0 +1,17 @@
//go:generate mockgen -package internal -destination collection_mock.go -source collection.go
package internal
import "github.com/globalsign/mgo"
type MgoCollection interface {
Find(query interface{}) *mgo.Query
FindId(id interface{}) *mgo.Query
Insert(docs ...interface{}) error
Pipe(pipeline interface{}) *mgo.Pipe
Remove(selector interface{}) error
RemoveAll(selector interface{}) (*mgo.ChangeInfo, error)
RemoveId(id interface{}) error
Update(selector, update interface{}) error
UpdateId(id, update interface{}) error
Upsert(selector, update interface{}) (*mgo.ChangeInfo, error)
}

View File

@@ -0,0 +1,180 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: collection.go
// Package internal is a generated GoMock package.
package internal
import (
mgo "github.com/globalsign/mgo"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockMgoCollection is a mock of MgoCollection interface
type MockMgoCollection struct {
ctrl *gomock.Controller
recorder *MockMgoCollectionMockRecorder
}
// MockMgoCollectionMockRecorder is the mock recorder for MockMgoCollection
type MockMgoCollectionMockRecorder struct {
mock *MockMgoCollection
}
// NewMockMgoCollection creates a new mock instance
func NewMockMgoCollection(ctrl *gomock.Controller) *MockMgoCollection {
mock := &MockMgoCollection{ctrl: ctrl}
mock.recorder = &MockMgoCollectionMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockMgoCollection) EXPECT() *MockMgoCollectionMockRecorder {
return m.recorder
}
// Find mocks base method
func (m *MockMgoCollection) Find(query interface{}) *mgo.Query {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Find", query)
ret0, _ := ret[0].(*mgo.Query)
return ret0
}
// Find indicates an expected call of Find
func (mr *MockMgoCollectionMockRecorder) Find(query interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockMgoCollection)(nil).Find), query)
}
// FindId mocks base method
func (m *MockMgoCollection) FindId(id interface{}) *mgo.Query {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FindId", id)
ret0, _ := ret[0].(*mgo.Query)
return ret0
}
// FindId indicates an expected call of FindId
func (mr *MockMgoCollectionMockRecorder) FindId(id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindId", reflect.TypeOf((*MockMgoCollection)(nil).FindId), id)
}
// Insert mocks base method
func (m *MockMgoCollection) Insert(docs ...interface{}) error {
m.ctrl.T.Helper()
varargs := []interface{}{}
for _, a := range docs {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Insert", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Insert indicates an expected call of Insert
func (mr *MockMgoCollectionMockRecorder) Insert(docs ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockMgoCollection)(nil).Insert), docs...)
}
// Pipe mocks base method
func (m *MockMgoCollection) Pipe(pipeline interface{}) *mgo.Pipe {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Pipe", pipeline)
ret0, _ := ret[0].(*mgo.Pipe)
return ret0
}
// Pipe indicates an expected call of Pipe
func (mr *MockMgoCollectionMockRecorder) Pipe(pipeline interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pipe", reflect.TypeOf((*MockMgoCollection)(nil).Pipe), pipeline)
}
// Remove mocks base method
func (m *MockMgoCollection) Remove(selector interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Remove", selector)
ret0, _ := ret[0].(error)
return ret0
}
// Remove indicates an expected call of Remove
func (mr *MockMgoCollectionMockRecorder) Remove(selector interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockMgoCollection)(nil).Remove), selector)
}
// RemoveAll mocks base method
func (m *MockMgoCollection) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveAll", selector)
ret0, _ := ret[0].(*mgo.ChangeInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RemoveAll indicates an expected call of RemoveAll
func (mr *MockMgoCollectionMockRecorder) RemoveAll(selector interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveAll", reflect.TypeOf((*MockMgoCollection)(nil).RemoveAll), selector)
}
// RemoveId mocks base method
func (m *MockMgoCollection) RemoveId(id interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveId", id)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveId indicates an expected call of RemoveId
func (mr *MockMgoCollectionMockRecorder) RemoveId(id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveId", reflect.TypeOf((*MockMgoCollection)(nil).RemoveId), id)
}
// Update mocks base method
func (m *MockMgoCollection) Update(selector, update interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", selector, update)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update
func (mr *MockMgoCollectionMockRecorder) Update(selector, update interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockMgoCollection)(nil).Update), selector, update)
}
// UpdateId mocks base method
func (m *MockMgoCollection) UpdateId(id, update interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateId", id, update)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateId indicates an expected call of UpdateId
func (mr *MockMgoCollectionMockRecorder) UpdateId(id, update interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateId", reflect.TypeOf((*MockMgoCollection)(nil).UpdateId), id, update)
}
// Upsert mocks base method
func (m *MockMgoCollection) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Upsert", selector, update)
ret0, _ := ret[0].(*mgo.ChangeInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Upsert indicates an expected call of Upsert
func (mr *MockMgoCollectionMockRecorder) Upsert(selector, update interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockMgoCollection)(nil).Upsert), selector, update)
}

View File

@@ -22,8 +22,8 @@ type (
} }
) )
func MustNewModel(url, database, collection string, opts ...Option) *Model { func MustNewModel(url, collection string, opts ...Option) *Model {
model, err := NewModel(url, database, collection, opts...) model, err := NewModel(url, collection, opts...)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@@ -31,15 +31,16 @@ func MustNewModel(url, database, collection string, opts ...Option) *Model {
return model return model
} }
func NewModel(url, database, collection string, opts ...Option) (*Model, error) { func NewModel(url, collection string, opts ...Option) (*Model, error) {
session, err := getConcurrentSession(url) session, err := getConcurrentSession(url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Model{ return &Model{
session: session, session: session,
db: session.DB(database), // If name is empty, the database name provided in the dialed URL is used instead
db: session.DB(""),
collection: collection, collection: collection,
opts: opts, opts: opts,
}, nil }, nil

View File

@@ -2,7 +2,7 @@ package mongoc
import ( import (
"github.com/globalsign/mgo" "github.com/globalsign/mgo"
"github.com/tal-tech/go-zero/core/stores/internal" "github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/mongo" "github.com/tal-tech/go-zero/core/stores/mongo"
"github.com/tal-tech/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
) )
@@ -12,7 +12,7 @@ var (
// can't use one SharedCalls per conn, because multiple conns may share the same cache key. // can't use one SharedCalls per conn, because multiple conns may share the same cache key.
sharedCalls = syncx.NewSharedCalls() sharedCalls = syncx.NewSharedCalls()
stats = internal.NewCacheStat("mongoc") stats = cache.NewCacheStat("mongoc")
) )
type ( type (
@@ -20,11 +20,11 @@ type (
cachedCollection struct { cachedCollection struct {
collection mongo.Collection collection mongo.Collection
cache internal.Cache cache cache.Cache
} }
) )
func newCollection(collection mongo.Collection, c internal.Cache) *cachedCollection { func newCollection(collection mongo.Collection, c cache.Cache) *cachedCollection {
return &cachedCollection{ return &cachedCollection{
collection: collection, collection: collection,
cache: c, cache: c,

View File

@@ -16,7 +16,7 @@ import (
"github.com/globalsign/mgo/bson" "github.com/globalsign/mgo/bson"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stat" "github.com/tal-tech/go-zero/core/stat"
"github.com/tal-tech/go-zero/core/stores/internal" "github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/mongo" "github.com/tal-tech/go-zero/core/stores/mongo"
"github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
) )
@@ -33,7 +33,7 @@ func TestStat(t *testing.T) {
} }
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound) cach := cache.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach) c := newCollection(dummyConn{}, cach)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@@ -56,7 +56,7 @@ func TestStatCacheFails(t *testing.T) {
defer log.SetOutput(os.Stdout) defer log.SetOutput(os.Stdout)
r := redis.NewRedis("localhost:59999", redis.NodeType) r := redis.NewRedis("localhost:59999", redis.NodeType)
cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound) cach := cache.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach) c := newCollection(dummyConn{}, cach)
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
@@ -79,7 +79,7 @@ func TestStatDbFails(t *testing.T) {
} }
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound) cach := cache.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach) c := newCollection(dummyConn{}, cach)
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
@@ -103,7 +103,7 @@ func TestStatFromMemory(t *testing.T) {
} }
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound) cach := cache.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach) c := newCollection(dummyConn{}, cach)
var all sync.WaitGroup var all sync.WaitGroup

View File

@@ -5,19 +5,18 @@ import (
"github.com/globalsign/mgo" "github.com/globalsign/mgo"
"github.com/tal-tech/go-zero/core/stores/cache" "github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/internal"
"github.com/tal-tech/go-zero/core/stores/mongo" "github.com/tal-tech/go-zero/core/stores/mongo"
"github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
) )
type Model struct { type Model struct {
*mongo.Model *mongo.Model
cache internal.Cache cache cache.Cache
generateCollection func(*mgo.Session) *cachedCollection generateCollection func(*mgo.Session) *cachedCollection
} }
func MustNewNodeModel(url, database, collection string, rds *redis.Redis, opts ...cache.Option) *Model { func MustNewNodeModel(url, collection string, rds *redis.Redis, opts ...cache.Option) *Model {
model, err := NewNodeModel(url, database, collection, rds, opts...) model, err := NewNodeModel(url, collection, rds, opts...)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@@ -25,8 +24,8 @@ func MustNewNodeModel(url, database, collection string, rds *redis.Redis, opts .
return model return model
} }
func MustNewModel(url, database, collection string, c cache.CacheConf, opts ...cache.Option) *Model { func MustNewModel(url, collection string, c cache.CacheConf, opts ...cache.Option) *Model {
model, err := NewModel(url, database, collection, c, opts...) model, err := NewModel(url, collection, c, opts...)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@@ -34,16 +33,16 @@ func MustNewModel(url, database, collection string, c cache.CacheConf, opts ...c
return model return model
} }
func NewNodeModel(url, database, collection string, rds *redis.Redis, opts ...cache.Option) (*Model, error) { func NewNodeModel(url, collection string, rds *redis.Redis, opts ...cache.Option) (*Model, error) {
c := internal.NewCacheNode(rds, sharedCalls, stats, mgo.ErrNotFound, opts...) c := cache.NewCacheNode(rds, sharedCalls, stats, mgo.ErrNotFound, opts...)
return createModel(url, database, collection, c, func(collection mongo.Collection) *cachedCollection { return createModel(url, collection, c, func(collection mongo.Collection) *cachedCollection {
return newCollection(collection, c) return newCollection(collection, c)
}) })
} }
func NewModel(url, database, collection string, conf cache.CacheConf, opts ...cache.Option) (*Model, error) { func NewModel(url, collection string, conf cache.CacheConf, opts ...cache.Option) (*Model, error) {
c := internal.NewCache(conf, sharedCalls, stats, mgo.ErrNotFound, opts...) c := cache.NewCache(conf, sharedCalls, stats, mgo.ErrNotFound, opts...)
return createModel(url, database, collection, c, func(collection mongo.Collection) *cachedCollection { return createModel(url, collection, c, func(collection mongo.Collection) *cachedCollection {
return newCollection(collection, c) return newCollection(collection, c)
}) })
} }
@@ -224,9 +223,9 @@ func (mm *Model) pipe(fn func(c *cachedCollection) mongo.Pipe) (mongo.Pipe, erro
return fn(mm.GetCollection(session)), nil return fn(mm.GetCollection(session)), nil
} }
func createModel(url, database, collection string, c internal.Cache, func createModel(url, collection string, c cache.Cache,
create func(mongo.Collection) *cachedCollection) (*Model, error) { create func(mongo.Collection) *cachedCollection) (*Model, error) {
model, err := mongo.NewModel(url, database, collection) model, err := mongo.NewModel(url, collection)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -0,0 +1,110 @@
package redis
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stringx"
)
func TestRedisConf(t *testing.T) {
tests := []struct {
name string
RedisConf
ok bool
}{
{
name: "missing host",
RedisConf: RedisConf{
Host: "",
Type: NodeType,
Pass: "",
},
ok: false,
},
{
name: "missing type",
RedisConf: RedisConf{
Host: "localhost:6379",
Type: "",
Pass: "",
},
ok: false,
},
{
name: "ok",
RedisConf: RedisConf{
Host: "localhost:6379",
Type: NodeType,
Pass: "",
},
ok: true,
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
if test.ok {
assert.Nil(t, test.RedisConf.Validate())
assert.NotNil(t, test.RedisConf.NewRedis())
} else {
assert.NotNil(t, test.RedisConf.Validate())
}
})
}
}
func TestRedisKeyConf(t *testing.T) {
tests := []struct {
name string
RedisKeyConf
ok bool
}{
{
name: "missing host",
RedisKeyConf: RedisKeyConf{
RedisConf: RedisConf{
Host: "",
Type: NodeType,
Pass: "",
},
Key: "foo",
},
ok: false,
},
{
name: "missing key",
RedisKeyConf: RedisKeyConf{
RedisConf: RedisConf{
Host: "localhost:6379",
Type: NodeType,
Pass: "",
},
Key: "",
},
ok: false,
},
{
name: "ok",
RedisKeyConf: RedisKeyConf{
RedisConf: RedisConf{
Host: "localhost:6379",
Type: NodeType,
Pass: "",
},
Key: "foo",
},
ok: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.ok {
assert.Nil(t, test.RedisKeyConf.Validate())
} else {
assert.NotNil(t, test.RedisKeyConf.Validate())
}
})
}
}

View File

@@ -7,11 +7,14 @@ import (
"time" "time"
"github.com/alicebob/miniredis" "github.com/alicebob/miniredis"
red "github.com/go-redis/redis"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestRedis_Exists(t *testing.T) { func TestRedis_Exists(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Exists("a")
assert.NotNil(t, err)
ok, err := client.Exists("a") ok, err := client.Exists("a")
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, ok) assert.False(t, ok)
@@ -24,7 +27,9 @@ func TestRedis_Exists(t *testing.T) {
func TestRedis_Eval(t *testing.T) { func TestRedis_Eval(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
_, err := client.Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"}) _, err := NewRedis(client.Addr, "").Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"})
assert.NotNil(t, err)
_, err = client.Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"})
assert.Equal(t, Nil, err) assert.Equal(t, Nil, err)
err = client.Set("key1", "value1") err = client.Set("key1", "value1")
assert.Nil(t, err) assert.Nil(t, err)
@@ -40,6 +45,8 @@ func TestRedis_Hgetall(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb")) assert.Nil(t, client.Hset("a", "bb", "bbb"))
_, err := NewRedis(client.Addr, "").Hgetall("a")
assert.NotNil(t, err)
vals, err := client.Hgetall("a") vals, err := client.Hgetall("a")
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, map[string]string{ assert.EqualValues(t, map[string]string{
@@ -51,8 +58,11 @@ func TestRedis_Hgetall(t *testing.T) {
func TestRedis_Hvals(t *testing.T) { func TestRedis_Hvals(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
assert.NotNil(t, NewRedis(client.Addr, "").Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb")) assert.Nil(t, client.Hset("a", "bb", "bbb"))
_, err := NewRedis(client.Addr, "").Hvals("a")
assert.NotNil(t, err)
vals, err := client.Hvals("a") vals, err := client.Hvals("a")
assert.Nil(t, err) assert.Nil(t, err)
assert.ElementsMatch(t, []string{"aaa", "bbb"}, vals) assert.ElementsMatch(t, []string{"aaa", "bbb"}, vals)
@@ -63,6 +73,8 @@ func TestRedis_Hsetnx(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb")) assert.Nil(t, client.Hset("a", "bb", "bbb"))
_, err := NewRedis(client.Addr, "").Hsetnx("a", "bb", "ccc")
assert.NotNil(t, err)
ok, err := client.Hsetnx("a", "bb", "ccc") ok, err := client.Hsetnx("a", "bb", "ccc")
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, ok) assert.False(t, ok)
@@ -79,6 +91,8 @@ func TestRedis_HdelHlen(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb")) assert.Nil(t, client.Hset("a", "bb", "bbb"))
_, err := NewRedis(client.Addr, "").Hlen("a")
assert.NotNil(t, err)
num, err := client.Hlen("a") num, err := client.Hlen("a")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 2, num) assert.Equal(t, 2, num)
@@ -93,6 +107,8 @@ func TestRedis_HdelHlen(t *testing.T) {
func TestRedis_HIncrBy(t *testing.T) { func TestRedis_HIncrBy(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Hincrby("key", "field", 2)
assert.NotNil(t, err)
val, err := client.Hincrby("key", "field", 2) val, err := client.Hincrby("key", "field", 2)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 2, val) assert.Equal(t, 2, val)
@@ -106,6 +122,8 @@ func TestRedis_Hkeys(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb")) assert.Nil(t, client.Hset("a", "bb", "bbb"))
_, err := NewRedis(client.Addr, "").Hkeys("a")
assert.NotNil(t, err)
vals, err := client.Hkeys("a") vals, err := client.Hkeys("a")
assert.Nil(t, err) assert.Nil(t, err)
assert.ElementsMatch(t, []string{"aa", "bb"}, vals) assert.ElementsMatch(t, []string{"aa", "bb"}, vals)
@@ -116,6 +134,8 @@ func TestRedis_Hmget(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb")) assert.Nil(t, client.Hset("a", "bb", "bbb"))
_, err := NewRedis(client.Addr, "").Hmget("a", "aa", "bb")
assert.NotNil(t, err)
vals, err := client.Hmget("a", "aa", "bb") vals, err := client.Hmget("a", "aa", "bb")
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []string{"aaa", "bbb"}, vals) assert.EqualValues(t, []string{"aaa", "bbb"}, vals)
@@ -127,6 +147,7 @@ func TestRedis_Hmget(t *testing.T) {
func TestRedis_Hmset(t *testing.T) { func TestRedis_Hmset(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
assert.NotNil(t, NewRedis(client.Addr, "").Hmset("a", nil))
assert.Nil(t, client.Hmset("a", map[string]string{ assert.Nil(t, client.Hmset("a", map[string]string{
"aa": "aaa", "aa": "aaa",
"bb": "bbb", "bb": "bbb",
@@ -139,6 +160,8 @@ func TestRedis_Hmset(t *testing.T) {
func TestRedis_Incr(t *testing.T) { func TestRedis_Incr(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Incr("a")
assert.NotNil(t, err)
val, err := client.Incr("a") val, err := client.Incr("a")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(1), val) assert.Equal(t, int64(1), val)
@@ -150,6 +173,8 @@ func TestRedis_Incr(t *testing.T) {
func TestRedis_IncrBy(t *testing.T) { func TestRedis_IncrBy(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Incrby("a", 2)
assert.NotNil(t, err)
val, err := client.Incrby("a", 2) val, err := client.Incrby("a", 2)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(2), val) assert.Equal(t, int64(2), val)
@@ -165,26 +190,49 @@ func TestRedis_Keys(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
err = client.Set("key2", "value2") err = client.Set("key2", "value2")
assert.Nil(t, err) assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").Keys("*")
assert.NotNil(t, err)
keys, err := client.Keys("*") keys, err := client.Keys("*")
assert.Nil(t, err) assert.Nil(t, err)
assert.ElementsMatch(t, []string{"key1", "key2"}, keys) assert.ElementsMatch(t, []string{"key1", "key2"}, keys)
}) })
} }
func TestRedis_HyperLogLog(t *testing.T) {
runOnRedis(t, func(client *Redis) {
client.Ping()
r := NewRedis(client.Addr, "")
_, err := r.Pfadd("key1")
assert.NotNil(t, err)
_, err = r.Pfcount("*")
assert.NotNil(t, err)
err = r.Pfmerge("*")
assert.NotNil(t, err)
})
}
func TestRedis_List(t *testing.T) { func TestRedis_List(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Lpush("key", "value1", "value2")
assert.NotNil(t, err)
val, err := client.Lpush("key", "value1", "value2") val, err := client.Lpush("key", "value1", "value2")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 2, val) assert.Equal(t, 2, val)
_, err = NewRedis(client.Addr, "").Rpush("key", "value3", "value4")
assert.NotNil(t, err)
val, err = client.Rpush("key", "value3", "value4") val, err = client.Rpush("key", "value3", "value4")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 4, val) assert.Equal(t, 4, val)
_, err = NewRedis(client.Addr, "").Llen("key")
assert.NotNil(t, err)
val, err = client.Llen("key") val, err = client.Llen("key")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 4, val) assert.Equal(t, 4, val)
vals, err := client.Lrange("key", 0, 10) vals, err := client.Lrange("key", 0, 10)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []string{"value2", "value1", "value3", "value4"}, vals) assert.EqualValues(t, []string{"value2", "value1", "value3", "value4"}, vals)
_, err = NewRedis(client.Addr, "").Lpop("key")
assert.NotNil(t, err)
v, err := client.Lpop("key") v, err := client.Lpop("key")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "value2", v) assert.Equal(t, "value2", v)
@@ -194,9 +242,13 @@ func TestRedis_List(t *testing.T) {
val, err = client.Rpush("key", "value3", "value3") val, err = client.Rpush("key", "value3", "value3")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 7, val) assert.Equal(t, 7, val)
_, err = NewRedis(client.Addr, "").Lrem("key", 2, "value1")
assert.NotNil(t, err)
n, err := client.Lrem("key", 2, "value1") n, err := client.Lrem("key", 2, "value1")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 2, n) assert.Equal(t, 2, n)
_, err = NewRedis(client.Addr, "").Lrange("key", 0, 10)
assert.NotNil(t, err)
vals, err = client.Lrange("key", 0, 10) vals, err = client.Lrange("key", 0, 10)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []string{"value2", "value3", "value4", "value3", "value3"}, vals) assert.EqualValues(t, []string{"value2", "value3", "value4", "value3", "value3"}, vals)
@@ -215,6 +267,8 @@ func TestRedis_Mget(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
err = client.Set("key2", "value2") err = client.Set("key2", "value2")
assert.Nil(t, err) assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").Mget("key1", "key0", "key2", "key3")
assert.NotNil(t, err)
vals, err := client.Mget("key1", "key0", "key2", "key3") vals, err := client.Mget("key1", "key0", "key2", "key3")
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []string{"value1", "", "value2", ""}, vals) assert.EqualValues(t, []string{"value1", "", "value2", ""}, vals)
@@ -223,7 +277,9 @@ func TestRedis_Mget(t *testing.T) {
func TestRedis_SetBit(t *testing.T) { func TestRedis_SetBit(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
err := client.SetBit("key", 1, 1) err := NewRedis(client.Addr, "").SetBit("key", 1, 1)
assert.NotNil(t, err)
err = client.SetBit("key", 1, 1)
assert.Nil(t, err) assert.Nil(t, err)
}) })
} }
@@ -232,6 +288,8 @@ func TestRedis_GetBit(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
err := client.SetBit("key", 2, 1) err := client.SetBit("key", 2, 1)
assert.Nil(t, err) assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").GetBit("key", 2)
assert.NotNil(t, err)
val, err := client.GetBit("key", 2) val, err := client.GetBit("key", 2)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 1, val) assert.Equal(t, 1, val)
@@ -240,6 +298,8 @@ func TestRedis_GetBit(t *testing.T) {
func TestRedis_Persist(t *testing.T) { func TestRedis_Persist(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Persist("key")
assert.NotNil(t, err)
ok, err := client.Persist("key") ok, err := client.Persist("key")
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, ok) assert.False(t, ok)
@@ -248,11 +308,15 @@ func TestRedis_Persist(t *testing.T) {
ok, err = client.Persist("key") ok, err = client.Persist("key")
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, ok) assert.False(t, ok)
err = NewRedis(client.Addr, "").Expire("key", 5)
assert.NotNil(t, err)
err = client.Expire("key", 5) err = client.Expire("key", 5)
assert.Nil(t, err) assert.Nil(t, err)
ok, err = client.Persist("key") ok, err = client.Persist("key")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
err = NewRedis(client.Addr, "").Expireat("key", time.Now().Unix()+5)
assert.NotNil(t, err)
err = client.Expireat("key", time.Now().Unix()+5) err = client.Expireat("key", time.Now().Unix()+5)
assert.Nil(t, err) assert.Nil(t, err)
ok, err = client.Persist("key") ok, err = client.Persist("key")
@@ -274,6 +338,8 @@ func TestRedis_Scan(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
err = client.Set("key2", "value2") err = client.Set("key2", "value2")
assert.Nil(t, err) assert.Nil(t, err)
_, _, err = NewRedis(client.Addr, "").Scan(0, "*", 100)
assert.NotNil(t, err)
keys, _, err := client.Scan(0, "*", 100) keys, _, err := client.Scan(0, "*", 100)
assert.Nil(t, err) assert.Nil(t, err)
assert.ElementsMatch(t, []string{"key1", "key2"}, keys) assert.ElementsMatch(t, []string{"key1", "key2"}, keys)
@@ -294,6 +360,8 @@ func TestRedis_Sscan(t *testing.T) {
var cursor uint64 = 0 var cursor uint64 = 0
sum := 0 sum := 0
for { for {
_, _, err := NewRedis(client.Addr, "").Sscan(key, cursor, "", 100)
assert.NotNil(t, err)
keys, next, err := client.Sscan(key, cursor, "", 100) keys, next, err := client.Sscan(key, cursor, "", 100)
assert.Nil(t, err) assert.Nil(t, err)
sum += len(keys) sum += len(keys)
@@ -304,6 +372,8 @@ func TestRedis_Sscan(t *testing.T) {
} }
assert.Equal(t, sum, 1550) assert.Equal(t, sum, 1550)
_, err = NewRedis(client.Addr, "").Del(key)
assert.NotNil(t, err)
_, err = client.Del(key) _, err = client.Del(key)
assert.Nil(t, err) assert.Nil(t, err)
}) })
@@ -311,46 +381,72 @@ func TestRedis_Sscan(t *testing.T) {
func TestRedis_Set(t *testing.T) { func TestRedis_Set(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Sadd("key", 1, 2, 3, 4)
assert.NotNil(t, err)
num, err := client.Sadd("key", 1, 2, 3, 4) num, err := client.Sadd("key", 1, 2, 3, 4)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 4, num) assert.Equal(t, 4, num)
_, err = NewRedis(client.Addr, "").Scard("key")
assert.NotNil(t, err)
val, err := client.Scard("key") val, err := client.Scard("key")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(4), val) assert.Equal(t, int64(4), val)
_, err = NewRedis(client.Addr, "").Sismember("key", 2)
assert.NotNil(t, err)
ok, err := client.Sismember("key", 2) ok, err := client.Sismember("key", 2)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
_, err = NewRedis(client.Addr, "").Srem("key", 3, 4)
assert.NotNil(t, err)
num, err = client.Srem("key", 3, 4) num, err = client.Srem("key", 3, 4)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 2, num) assert.Equal(t, 2, num)
_, err = NewRedis(client.Addr, "").Smembers("key")
assert.NotNil(t, err)
vals, err := client.Smembers("key") vals, err := client.Smembers("key")
assert.Nil(t, err) assert.Nil(t, err)
assert.ElementsMatch(t, []string{"1", "2"}, vals) assert.ElementsMatch(t, []string{"1", "2"}, vals)
_, err = NewRedis(client.Addr, "").Srandmember("key", 1)
assert.NotNil(t, err)
members, err := client.Srandmember("key", 1) members, err := client.Srandmember("key", 1)
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, members, 1) assert.Len(t, members, 1)
assert.Contains(t, []string{"1", "2"}, members[0]) assert.Contains(t, []string{"1", "2"}, members[0])
_, err = NewRedis(client.Addr, "").Spop("key")
assert.NotNil(t, err)
member, err := client.Spop("key") member, err := client.Spop("key")
assert.Nil(t, err) assert.Nil(t, err)
assert.Contains(t, []string{"1", "2"}, member) assert.Contains(t, []string{"1", "2"}, member)
_, err = NewRedis(client.Addr, "").Smembers("key")
assert.NotNil(t, err)
vals, err = client.Smembers("key") vals, err = client.Smembers("key")
assert.Nil(t, err) assert.Nil(t, err)
assert.NotContains(t, vals, member) assert.NotContains(t, vals, member)
_, err = NewRedis(client.Addr, "").Sadd("key1", 1, 2, 3, 4)
assert.NotNil(t, err)
num, err = client.Sadd("key1", 1, 2, 3, 4) num, err = client.Sadd("key1", 1, 2, 3, 4)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 4, num) assert.Equal(t, 4, num)
num, err = client.Sadd("key2", 2, 3, 4, 5) num, err = client.Sadd("key2", 2, 3, 4, 5)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 4, num) assert.Equal(t, 4, num)
_, err = NewRedis(client.Addr, "").Sunion("key1", "key2")
assert.NotNil(t, err)
vals, err = client.Sunion("key1", "key2") vals, err = client.Sunion("key1", "key2")
assert.Nil(t, err) assert.Nil(t, err)
assert.ElementsMatch(t, []string{"1", "2", "3", "4", "5"}, vals) assert.ElementsMatch(t, []string{"1", "2", "3", "4", "5"}, vals)
_, err = NewRedis(client.Addr, "").Sunionstore("key3", "key1", "key2")
assert.NotNil(t, err)
num, err = client.Sunionstore("key3", "key1", "key2") num, err = client.Sunionstore("key3", "key1", "key2")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 5, num) assert.Equal(t, 5, num)
_, err = NewRedis(client.Addr, "").Sdiff("key1", "key2")
assert.NotNil(t, err)
vals, err = client.Sdiff("key1", "key2") vals, err = client.Sdiff("key1", "key2")
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []string{"1"}, vals) assert.EqualValues(t, []string{"1"}, vals)
_, err = NewRedis(client.Addr, "").Sdiffstore("key4", "key1", "key2")
assert.NotNil(t, err)
num, err = client.Sdiffstore("key4", "key1", "key2") num, err = client.Sdiffstore("key4", "key1", "key2")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 1, num) assert.Equal(t, 1, num)
@@ -359,8 +455,12 @@ func TestRedis_Set(t *testing.T) {
func TestRedis_SetGetDel(t *testing.T) { func TestRedis_SetGetDel(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
err := client.Set("hello", "world") err := NewRedis(client.Addr, "").Set("hello", "world")
assert.NotNil(t, err)
err = client.Set("hello", "world")
assert.Nil(t, err) assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").Get("hello")
assert.NotNil(t, err)
val, err := client.Get("hello") val, err := client.Get("hello")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "world", val) assert.Equal(t, "world", val)
@@ -372,8 +472,12 @@ func TestRedis_SetGetDel(t *testing.T) {
func TestRedis_SetExNx(t *testing.T) { func TestRedis_SetExNx(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
err := client.Setex("hello", "world", 5) err := NewRedis(client.Addr, "").Setex("hello", "world", 5)
assert.NotNil(t, err)
err = client.Setex("hello", "world", 5)
assert.Nil(t, err) assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").Setnx("hello", "newworld")
assert.NotNil(t, err)
ok, err := client.Setnx("hello", "newworld") ok, err := client.Setnx("hello", "newworld")
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, ok) assert.False(t, ok)
@@ -389,6 +493,8 @@ func TestRedis_SetExNx(t *testing.T) {
ttl, err := client.Ttl("hello") ttl, err := client.Ttl("hello")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ttl > 0) assert.True(t, ttl > 0)
_, err = NewRedis(client.Addr, "").SetnxEx("newhello", "newworld", 5)
assert.NotNil(t, err)
ok, err = client.SetnxEx("newhello", "newworld", 5) ok, err = client.SetnxEx("newhello", "newworld", 5)
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, ok) assert.False(t, ok)
@@ -408,12 +514,18 @@ func TestRedis_SetGetDelHashField(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
err := client.Hset("key", "field", "value") err := client.Hset("key", "field", "value")
assert.Nil(t, err) assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").Hget("key", "field")
assert.NotNil(t, err)
val, err := client.Hget("key", "field") val, err := client.Hget("key", "field")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "value", val) assert.Equal(t, "value", val)
_, err = NewRedis(client.Addr, "").Hexists("key", "field")
assert.NotNil(t, err)
ok, err := client.Hexists("key", "field") ok, err := client.Hexists("key", "field")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
_, err = NewRedis(client.Addr, "").Hdel("key", "field")
assert.NotNil(t, err)
ret, err := client.Hdel("key", "field") ret, err := client.Hdel("key", "field")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ret) assert.True(t, ret)
@@ -434,23 +546,50 @@ func TestRedis_SortedSet(t *testing.T) {
val, err := client.Zscore("key", "value1") val, err := client.Zscore("key", "value1")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(2), val) assert.Equal(t, int64(2), val)
_, err = NewRedis(client.Addr, "").Zincrby("key", 3, "value1")
assert.NotNil(t, err)
val, err = client.Zincrby("key", 3, "value1") val, err = client.Zincrby("key", 3, "value1")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(5), val) assert.Equal(t, int64(5), val)
_, err = NewRedis(client.Addr, "").Zscore("key", "value1")
assert.NotNil(t, err)
val, err = client.Zscore("key", "value1") val, err = client.Zscore("key", "value1")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(5), val) assert.Equal(t, int64(5), val)
ok, err = client.Zadd("key", 6, "value2") val, err = NewRedis(client.Addr, "").Zadds("key")
assert.NotNil(t, err)
val, err = client.Zadds("key", Pair{
Key: "value2",
Score: 6,
}, Pair{
Key: "value3",
Score: 7,
})
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.Equal(t, int64(2), val)
ok, err = client.Zadd("key", 7, "value3") pairs, err := NewRedis(client.Addr, "").ZRevRangeWithScores("key", 1, 3)
assert.NotNil(t, err)
pairs, err = client.ZRevRangeWithScores("key", 1, 3)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.EqualValues(t, []Pair{
{
Key: "value2",
Score: 6,
},
{
Key: "value1",
Score: 5,
},
}, pairs)
rank, err := client.Zrank("key", "value2") rank, err := client.Zrank("key", "value2")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(1), rank) assert.Equal(t, int64(1), rank)
_, err = NewRedis(client.Addr, "").Zrank("key", "value4")
assert.NotNil(t, err)
_, err = client.Zrank("key", "value4") _, err = client.Zrank("key", "value4")
assert.Equal(t, Nil, err) assert.Equal(t, Nil, err)
_, err = NewRedis(client.Addr, "").Zrem("key", "value2", "value3")
assert.NotNil(t, err)
num, err := client.Zrem("key", "value2", "value3") num, err := client.Zrem("key", "value2", "value3")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 2, num) assert.Equal(t, 2, num)
@@ -463,31 +602,47 @@ func TestRedis_SortedSet(t *testing.T) {
ok, err = client.Zadd("key", 8, "value4") ok, err = client.Zadd("key", 8, "value4")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
_, err = NewRedis(client.Addr, "").Zremrangebyscore("key", 6, 7)
assert.NotNil(t, err)
num, err = client.Zremrangebyscore("key", 6, 7) num, err = client.Zremrangebyscore("key", 6, 7)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 2, num) assert.Equal(t, 2, num)
ok, err = client.Zadd("key", 6, "value2") ok, err = client.Zadd("key", 6, "value2")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
_, err = NewRedis(client.Addr, "").Zadd("key", 7, "value3")
assert.NotNil(t, err)
ok, err = client.Zadd("key", 7, "value3") ok, err = client.Zadd("key", 7, "value3")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
_, err = NewRedis(client.Addr, "").Zcount("key", 6, 7)
assert.NotNil(t, err)
num, err = client.Zcount("key", 6, 7) num, err = client.Zcount("key", 6, 7)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 2, num) assert.Equal(t, 2, num)
_, err = NewRedis(client.Addr, "").Zremrangebyrank("key", 1, 2)
assert.NotNil(t, err)
num, err = client.Zremrangebyrank("key", 1, 2) num, err = client.Zremrangebyrank("key", 1, 2)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 2, num) assert.Equal(t, 2, num)
_, err = NewRedis(client.Addr, "").Zcard("key")
assert.NotNil(t, err)
card, err := client.Zcard("key") card, err := client.Zcard("key")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 2, card) assert.Equal(t, 2, card)
_, err = NewRedis(client.Addr, "").Zrange("key", 0, -1)
assert.NotNil(t, err)
vals, err := client.Zrange("key", 0, -1) vals, err := client.Zrange("key", 0, -1)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []string{"value1", "value4"}, vals) assert.EqualValues(t, []string{"value1", "value4"}, vals)
_, err = NewRedis(client.Addr, "").Zrevrange("key", 0, -1)
assert.NotNil(t, err)
vals, err = client.Zrevrange("key", 0, -1) vals, err = client.Zrevrange("key", 0, -1)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []string{"value4", "value1"}, vals) assert.EqualValues(t, []string{"value4", "value1"}, vals)
pairs, err := client.ZrangeWithScores("key", 0, -1) _, err = NewRedis(client.Addr, "").ZrangeWithScores("key", 0, -1)
assert.NotNil(t, err)
pairs, err = client.ZrangeWithScores("key", 0, -1)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []Pair{ assert.EqualValues(t, []Pair{
{ {
@@ -499,6 +654,8 @@ func TestRedis_SortedSet(t *testing.T) {
Score: 8, Score: 8,
}, },
}, pairs) }, pairs)
_, err = NewRedis(client.Addr, "").ZrangebyscoreWithScores("key", 5, 8)
assert.NotNil(t, err)
pairs, err = client.ZrangebyscoreWithScores("key", 5, 8) pairs, err = client.ZrangebyscoreWithScores("key", 5, 8)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []Pair{ assert.EqualValues(t, []Pair{
@@ -511,6 +668,9 @@ func TestRedis_SortedSet(t *testing.T) {
Score: 8, Score: 8,
}, },
}, pairs) }, pairs)
_, err = NewRedis(client.Addr, "").ZrangebyscoreWithScoresAndLimit(
"key", 5, 8, 1, 1)
assert.NotNil(t, err)
pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1) pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []Pair{ assert.EqualValues(t, []Pair{
@@ -519,6 +679,11 @@ func TestRedis_SortedSet(t *testing.T) {
Score: 8, Score: 8,
}, },
}, pairs) }, pairs)
pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 0)
assert.Nil(t, err)
assert.Equal(t, 0, len(pairs))
_, err = NewRedis(client.Addr, "").ZrevrangebyscoreWithScores("key", 5, 8)
assert.NotNil(t, err)
pairs, err = client.ZrevrangebyscoreWithScores("key", 5, 8) pairs, err = client.ZrevrangebyscoreWithScores("key", 5, 8)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []Pair{ assert.EqualValues(t, []Pair{
@@ -531,6 +696,9 @@ func TestRedis_SortedSet(t *testing.T) {
Score: 5, Score: 5,
}, },
}, pairs) }, pairs)
_, err = NewRedis(client.Addr, "").ZrevrangebyscoreWithScoresAndLimit(
"key", 5, 8, 1, 1)
assert.NotNil(t, err)
pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1) pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []Pair{ assert.EqualValues(t, []Pair{
@@ -539,11 +707,17 @@ func TestRedis_SortedSet(t *testing.T) {
Score: 5, Score: 5,
}, },
}, pairs) }, pairs)
pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 0)
assert.Nil(t, err)
assert.Equal(t, 0, len(pairs))
}) })
} }
func TestRedis_Pipelined(t *testing.T) { func TestRedis_Pipelined(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
assert.NotNil(t, NewRedis(client.Addr, "").Pipelined(func(pipeliner Pipeliner) error {
return nil
}))
err := client.Pipelined( err := client.Pipelined(
func(pipe Pipeliner) error { func(pipe Pipeliner) error {
pipe.Incr("pipelined_counter") pipe.Incr("pipelined_counter")
@@ -553,6 +727,8 @@ func TestRedis_Pipelined(t *testing.T) {
}, },
) )
assert.Nil(t, err) assert.Nil(t, err)
_, err = NewRedis(client.Addr, "").Ttl("pipelined_counter")
assert.NotNil(t, err)
ttl, err := client.Ttl("pipelined_counter") ttl, err := client.Ttl("pipelined_counter")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 3600, ttl) assert.Equal(t, 3600, ttl)
@@ -565,6 +741,76 @@ func TestRedis_Pipelined(t *testing.T) {
}) })
} }
func TestRedisString(t *testing.T) {
runOnRedis(t, func(client *Redis) {
client.Ping()
_, err := getRedis(NewRedis(client.Addr, ClusterType))
assert.Nil(t, err)
assert.Equal(t, client.Addr, client.String())
assert.NotNil(t, NewRedis(client.Addr, "").Ping())
})
}
func TestRedisScriptLoad(t *testing.T) {
runOnRedis(t, func(client *Redis) {
client.Ping()
_, err := NewRedis(client.Addr, "").scriptLoad("foo")
assert.NotNil(t, err)
_, err = client.scriptLoad("foo")
assert.NotNil(t, err)
})
}
func TestRedisToPairs(t *testing.T) {
pairs := toPairs([]red.Z{
{
Member: 1,
Score: 1,
},
{
Member: 2,
Score: 2,
},
})
assert.EqualValues(t, []Pair{
{
Key: "1",
Score: 1,
},
{
Key: "2",
Score: 2,
},
}, pairs)
}
func TestRedisToStrings(t *testing.T) {
vals := toStrings([]interface{}{1, 2})
assert.EqualValues(t, []string{"1", "2"}, vals)
}
func TestRedisBlpop(t *testing.T) {
runOnRedis(t, func(client *Redis) {
client.Ping()
var node mockedNode
_, err := client.Blpop(nil, "foo")
assert.NotNil(t, err)
_, err = client.Blpop(node, "foo")
assert.NotNil(t, err)
})
}
func TestRedisBlpopEx(t *testing.T) {
runOnRedis(t, func(client *Redis) {
client.Ping()
var node mockedNode
_, _, err := client.BlpopEx(nil, "foo")
assert.NotNil(t, err)
_, _, err = client.BlpopEx(node, "foo")
assert.NotNil(t, err)
})
}
func runOnRedis(t *testing.T, fn func(client *Redis)) { func runOnRedis(t *testing.T, fn func(client *Redis)) {
s, err := miniredis.Run() s, err := miniredis.Run()
assert.Nil(t, err) assert.Nil(t, err)
@@ -576,8 +822,18 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) {
t.Error(err) t.Error(err)
} }
client.Close() if client != nil {
client.Close()
}
}() }()
fn(NewRedis(s.Addr(), NodeType)) fn(NewRedis(s.Addr(), NodeType))
} }
type mockedNode struct {
RedisNode
}
func (n mockedNode) BLPop(timeout time.Duration, keys ...string) *red.StringSliceCmd {
return red.NewStringSliceCmd("foo", "bar")
}

View File

@@ -5,7 +5,6 @@ import (
"time" "time"
"github.com/tal-tech/go-zero/core/stores/cache" "github.com/tal-tech/go-zero/core/stores/cache"
"github.com/tal-tech/go-zero/core/stores/internal"
"github.com/tal-tech/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
"github.com/tal-tech/go-zero/core/stores/sqlx" "github.com/tal-tech/go-zero/core/stores/sqlx"
"github.com/tal-tech/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
@@ -19,7 +18,7 @@ var (
// can't use one SharedCalls per conn, because multiple conns may share the same cache key. // can't use one SharedCalls per conn, because multiple conns may share the same cache key.
exclusiveCalls = syncx.NewSharedCalls() exclusiveCalls = syncx.NewSharedCalls()
stats = internal.NewCacheStat("sqlc") stats = cache.NewCacheStat("sqlc")
) )
type ( type (
@@ -30,21 +29,21 @@ type (
CachedConn struct { CachedConn struct {
db sqlx.SqlConn db sqlx.SqlConn
cache internal.Cache cache cache.Cache
} }
) )
func NewNodeConn(db sqlx.SqlConn, rds *redis.Redis, opts ...cache.Option) CachedConn { func NewNodeConn(db sqlx.SqlConn, rds *redis.Redis, opts ...cache.Option) CachedConn {
return CachedConn{ return CachedConn{
db: db, db: db,
cache: internal.NewCacheNode(rds, exclusiveCalls, stats, sql.ErrNoRows, opts...), cache: cache.NewCacheNode(rds, exclusiveCalls, stats, sql.ErrNoRows, opts...),
} }
} }
func NewConn(db sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) CachedConn { func NewConn(db sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) CachedConn {
return CachedConn{ return CachedConn{
db: db, db: db,
cache: internal.NewCache(c, exclusiveCalls, stats, sql.ErrNoRows, opts...), cache: cache.NewCache(c, exclusiveCalls, stats, sql.ErrNoRows, opts...),
} }
} }
@@ -83,6 +82,7 @@ func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary
indexQuery IndexQueryFn, primaryQuery PrimaryQueryFn) error { indexQuery IndexQueryFn, primaryQuery PrimaryQueryFn) error {
var primaryKey interface{} var primaryKey interface{}
var found bool var found bool
if err := cc.cache.TakeWithExpire(&primaryKey, key, func(val interface{}, expire time.Duration) (err error) { if err := cc.cache.TakeWithExpire(&primaryKey, key, func(val interface{}, expire time.Duration) (err error) {
primaryKey, err = indexQuery(cc.db, v) primaryKey, err = indexQuery(cc.db, v)
if err != nil { if err != nil {

View File

@@ -79,9 +79,29 @@ func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) {
} }
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10)) c := NewConn(dummySqlConn{}, cache.CacheConf{
{
RedisConf: redis.RedisConf{
Host: s.Addr(),
Type: redis.NodeType,
},
Weight: 100,
},
}, cache.WithExpiry(time.Second*10))
var str string var str string
err = c.QueryRowIndex(&str, "index", func(s interface{}) string {
return fmt.Sprintf("%s/1234", s)
}, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) {
*v.(*string) = "zero"
return "primary", errors.New("foo")
}, func(conn sqlx.SqlConn, v, pri interface{}) error {
assert.Equal(t, "primary", pri)
*v.(*string) = "xin"
return nil
})
assert.NotNil(t, err)
err = c.QueryRowIndex(&str, "index", func(s interface{}) string { err = c.QueryRowIndex(&str, "index", func(s interface{}) string {
return fmt.Sprintf("%s/1234", s) return fmt.Sprintf("%s/1234", s)
}, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) { }, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) {
@@ -135,6 +155,103 @@ func TestCachedConn_QueryRowIndex_HasCache(t *testing.T) {
assert.Equal(t, `"xin"`, val) assert.Equal(t, `"xin"`, val)
} }
func TestCachedConn_QueryRowIndex_HasCache_IntPrimary(t *testing.T) {
const (
primaryInt8 int8 = 100
primaryInt16 int16 = 10000
primaryInt32 int32 = 10000000
primaryInt64 int64 = 10000000
primaryUint8 uint8 = 100
primaryUint16 uint16 = 10000
primaryUint32 uint32 = 10000000
primaryUint64 uint64 = 10000000
)
tests := []struct {
name string
primary interface{}
primaryCache string
}{
{
name: "int8 primary",
primary: primaryInt8,
primaryCache: fmt.Sprint(primaryInt8),
},
{
name: "int16 primary",
primary: primaryInt16,
primaryCache: fmt.Sprint(primaryInt16),
},
{
name: "int32 primary",
primary: primaryInt32,
primaryCache: fmt.Sprint(primaryInt32),
},
{
name: "int64 primary",
primary: primaryInt64,
primaryCache: fmt.Sprint(primaryInt64),
},
{
name: "uint8 primary",
primary: primaryUint8,
primaryCache: fmt.Sprint(primaryUint8),
},
{
name: "uint16 primary",
primary: primaryUint16,
primaryCache: fmt.Sprint(primaryUint16),
},
{
name: "uint32 primary",
primary: primaryUint32,
primaryCache: fmt.Sprint(primaryUint32),
},
{
name: "uint64 primary",
primary: primaryUint64,
primaryCache: fmt.Sprint(primaryUint64),
},
}
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
defer s.Close()
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
resetStats()
s.FlushAll()
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
cache.WithNotFoundExpiry(time.Second))
var str string
r.Set("index", test.primaryCache)
err = c.QueryRowIndex(&str, "index", func(s interface{}) string {
return fmt.Sprintf("%v/1234", s)
}, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) {
assert.Fail(t, "should not go here")
return test.primary, nil
}, func(conn sqlx.SqlConn, v, primary interface{}) error {
*v.(*string) = "xin"
assert.Equal(t, primary, primary)
return nil
})
assert.Nil(t, err)
assert.Equal(t, "xin", str)
val, err := r.Get("index")
assert.Nil(t, err)
assert.Equal(t, test.primaryCache, val)
val, err = r.Get(test.primaryCache + "/1234")
assert.Nil(t, err)
assert.Equal(t, `"xin"`, val)
})
}
}
func TestCachedConn_QueryRowIndex_HasWrongCache(t *testing.T) { func TestCachedConn_QueryRowIndex_HasWrongCache(t *testing.T) {
caches := map[string]string{ caches := map[string]string{
"index": "primary", "index": "primary",
@@ -148,6 +265,8 @@ func TestCachedConn_QueryRowIndex_HasWrongCache(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
s.FlushAll()
defer s.Close()
r := redis.NewRedis(s.Addr(), redis.NodeType) r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10), c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
@@ -401,6 +520,10 @@ func TestCachedConnExecDropCache(t *testing.T) {
assert.True(t, conn.execValue) assert.True(t, conn.execValue)
_, err = s.Get(key) _, err = s.Get(key)
assert.Exactly(t, miniredis.ErrKeyNotFound, err) assert.Exactly(t, miniredis.ErrKeyNotFound, err)
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
return nil, errors.New("foo")
}, key)
assert.NotNil(t, err)
} }
func TestCachedConnExecDropCacheFailed(t *testing.T) { func TestCachedConnExecDropCacheFailed(t *testing.T) {
@@ -446,6 +569,31 @@ func TestCachedConnTransact(t *testing.T) {
assert.True(t, conn.transactValue) assert.True(t, conn.transactValue)
} }
func TestQueryRowNoCache(t *testing.T) {
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
const (
key = "user"
value = "any"
)
var user string
var ran bool
r := redis.NewRedis(s.Addr(), redis.NodeType)
conn := dummySqlConn{queryRow: func(v interface{}, q string, args ...interface{}) error {
user = value
ran = true
return nil
}}
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
err = c.QueryRowNoCache(&user, key)
assert.Nil(t, err)
assert.Equal(t, value, user)
assert.True(t, ran)
}
func resetStats() { func resetStats() {
atomic.StoreUint64(&stats.Total, 0) atomic.StoreUint64(&stats.Total, 0)
atomic.StoreUint64(&stats.Hit, 0) atomic.StoreUint64(&stats.Hit, 0)
@@ -454,6 +602,7 @@ func resetStats() {
} }
type dummySqlConn struct { type dummySqlConn struct {
queryRow func(interface{}, string, ...interface{}) error
} }
func (d dummySqlConn) Exec(query string, args ...interface{}) (sql.Result, error) { func (d dummySqlConn) Exec(query string, args ...interface{}) (sql.Result, error) {
@@ -465,6 +614,9 @@ func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) {
} }
func (d dummySqlConn) QueryRow(v interface{}, query string, args ...interface{}) error { func (d dummySqlConn) QueryRow(v interface{}, query string, args ...interface{}) error {
if d.queryRow != nil {
return d.queryRow(v, query, args...)
}
return nil return nil
} }

View File

@@ -2,6 +2,7 @@ package sqlx
import ( import (
"database/sql" "database/sql"
"errors"
"strconv" "strconv"
"testing" "testing"
@@ -11,14 +12,15 @@ import (
) )
type mockedConn struct { type mockedConn struct {
query string query string
args []interface{} args []interface{}
execErr error
} }
func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) { func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
c.query = query c.query = query
c.args = args c.args = args
return nil, nil return nil, c.execErr
} }
func (c *mockedConn) Prepare(query string) (StmtSession, error) { func (c *mockedConn) Prepare(query string) (StmtSession, error) {
@@ -68,9 +70,12 @@ func TestBulkInserterSuffix(t *testing.T) {
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+ inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`) `(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
assert.Nil(t, err) assert.Nil(t, err)
assert.Nil(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user, count) VALUES`+
`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`))
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i)) assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
} }
inserter.SetResultHandler(func(result sql.Result, err error) {})
inserter.Flush() inserter.Flush()
assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+ assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+ `('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
@@ -80,6 +85,33 @@ func TestBulkInserterSuffix(t *testing.T) {
}) })
} }
func TestBulkInserterBadStatement(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn
_, err := NewBulkInserter(&conn, "foo")
assert.NotNil(t, err)
})
}
func TestBulkInserter_Update(t *testing.T) {
conn := mockedConn{
execErr: errors.New("foo"),
}
_, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES()`)
assert.NotNil(t, err)
_, err = NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?)`)
assert.NotNil(t, err)
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
assert.Nil(t, err)
inserter.inserter.Execute([]string{"bar"})
inserter.SetResultHandler(func(result sql.Result, err error) {
})
inserter.UpdateOrDelete(func() {})
inserter.inserter.Execute([]string(nil))
assert.NotNil(t, inserter.UpdateStmt("foo"))
assert.NotNil(t, inserter.Insert("foo", "bar"))
}
func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) { func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
logx.Disable() logx.Disable()

View File

@@ -2,6 +2,7 @@ package sqlx
import ( import (
"database/sql" "database/sql"
"errors"
"testing" "testing"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
@@ -22,6 +23,18 @@ func TestUnmarshalRowBool(t *testing.T) {
}) })
} }
func TestUnmarshalRowBoolNotSettable(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value bool
assert.NotNil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select value from users where user=?", "anyone"))
})
}
func TestUnmarshalRowInt(t *testing.T) { func TestUnmarshalRowInt(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
@@ -228,6 +241,22 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
}) })
} }
func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
var value = new(struct {
Age *int `db:"age"`
Name string `db:"name"`
})
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.NotNil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"))
})
}
func TestUnmarshalRowsBool(t *testing.T) { func TestUnmarshalRowsBool(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []bool{true, false} var expect = []bool{true, false}
@@ -955,6 +984,62 @@ func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
}) })
} }
func TestUnmarshalRowError(t *testing.T) {
tests := []struct {
name string
colErr error
scanErr error
err error
next int
validate func(err error)
}{
{
name: "with error",
err: errors.New("foo"),
validate: func(err error) {
assert.NotNil(t, err)
},
},
{
name: "without next",
validate: func(err error) {
assert.Equal(t, ErrNotFound, err)
},
},
{
name: "with error",
scanErr: errors.New("foo"),
next: 1,
validate: func(err error) {
assert.Equal(t, ErrNotFound, err)
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs(
"anyone").WillReturnRows(rs)
var r struct {
User string `db:"user"`
Age int `db:"age"`
}
test.validate(query(db, func(rows *sql.Rows) error {
scanner := mockedScanner{
colErr: test.colErr,
scanErr: test.scanErr,
err: test.err,
}
return unmarshalRow(&r, &scanner, false)
}, "select age from users where user=?", "anyone"))
})
})
}
}
func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) { func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
logx.Disable() logx.Disable()
@@ -970,3 +1055,30 @@ func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
t.Errorf("there were unfulfilled expectations: %s", err) t.Errorf("there were unfulfilled expectations: %s", err)
} }
} }
type mockedScanner struct {
colErr error
scanErr error
err error
next int
}
func (m *mockedScanner) Columns() ([]string, error) {
return nil, m.colErr
}
func (m *mockedScanner) Err() error {
return m.err
}
func (m *mockedScanner) Next() bool {
if m.next > 0 {
m.next--
return true
}
return false
}
func (m *mockedScanner) Scan(v ...interface{}) error {
return m.scanErr
}

View File

@@ -6,18 +6,22 @@ import (
"github.com/tal-tech/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
) )
var ErrReturn = errors.New("discarding limited token, resource pool is full, someone returned multiple times") // ErrLimitReturn indicates that the more than borrowed elements were returned.
var ErrLimitReturn = errors.New("discarding limited token, resource pool is full, someone returned multiple times")
// Limit controls the concurrent requests.
type Limit struct { type Limit struct {
pool chan lang.PlaceholderType pool chan lang.PlaceholderType
} }
// NewLimit creates a Limit that can borrow n elements from it concurrently.
func NewLimit(n int) Limit { func NewLimit(n int) Limit {
return Limit{ return Limit{
pool: make(chan lang.PlaceholderType, n), pool: make(chan lang.PlaceholderType, n),
} }
} }
// Borrow borrows an element from Limit in blocking mode.
func (l Limit) Borrow() { func (l Limit) Borrow() {
l.pool <- lang.Placeholder l.pool <- lang.Placeholder
} }
@@ -28,10 +32,12 @@ func (l Limit) Return() error {
case <-l.pool: case <-l.pool:
return nil return nil
default: default:
return ErrReturn return ErrLimitReturn
} }
} }
// TryBorrow tries to borrow an element from Limit, in non-blocking mode.
// If success, true returned, false for otherwise.
func (l Limit) TryBorrow() bool { func (l Limit) TryBorrow() bool {
select { select {
case l.pool <- lang.Placeholder: case l.pool <- lang.Placeholder:

View File

@@ -13,5 +13,5 @@ func TestLimit(t *testing.T) {
assert.False(t, limit.TryBorrow()) assert.False(t, limit.TryBorrow())
assert.Nil(t, limit.Return()) assert.Nil(t, limit.Return())
assert.Nil(t, limit.Return()) assert.Nil(t, limit.Return())
assert.Equal(t, ErrReturn, limit.Return()) assert.Equal(t, ErrLimitReturn, limit.Return())
} }

View File

@@ -33,35 +33,42 @@ func NewSharedCalls() SharedCalls {
} }
func (g *sharedGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) { func (g *sharedGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
g.lock.Lock() c, done := g.createCall(key, fn)
if c, ok := g.calls[key]; ok { if done {
g.lock.Unlock()
c.wg.Wait()
return c.val, c.err return c.val, c.err
} }
c := g.makeCall(key, fn) g.makeCall(c, key, fn)
return c.val, c.err return c.val, c.err
} }
func (g *sharedGroup) DoEx(key string, fn func() (interface{}, error)) (val interface{}, fresh bool, err error) { func (g *sharedGroup) DoEx(key string, fn func() (interface{}, error)) (val interface{}, fresh bool, err error) {
c, done := g.createCall(key, fn)
if done {
return c.val, false, c.err
}
g.makeCall(c, key, fn)
return c.val, true, c.err
}
func (g *sharedGroup) createCall(key string, fn func() (interface{}, error)) (c *call, done bool) {
g.lock.Lock() g.lock.Lock()
if c, ok := g.calls[key]; ok { if c, ok := g.calls[key]; ok {
g.lock.Unlock() g.lock.Unlock()
c.wg.Wait() c.wg.Wait()
return c.val, false, c.err return c, true
} }
c := g.makeCall(key, fn) c = new(call)
return c.val, true, c.err
}
func (g *sharedGroup) makeCall(key string, fn func() (interface{}, error)) *call {
c := new(call)
c.wg.Add(1) c.wg.Add(1)
g.calls[key] = c g.calls[key] = c
g.lock.Unlock() g.lock.Unlock()
return c, false
}
func (g *sharedGroup) makeCall(c *call, key string, fn func() (interface{}, error)) {
defer func() { defer func() {
// delete key first, done later. can't reverse the order, because if reverse, // delete key first, done later. can't reverse the order, because if reverse,
// another Do call might wg.Wait() without get notified with wg.Done() // another Do call might wg.Wait() without get notified with wg.Done()
@@ -72,5 +79,4 @@ func (g *sharedGroup) makeCall(key string, fn func() (interface{}, error)) *call
}() }()
c.val, c.err = fn() c.val, c.err = fn()
return c
} }

View File

@@ -68,6 +68,38 @@ func TestExclusiveCallDoDupSuppress(t *testing.T) {
} }
} }
func TestExclusiveCallDoDiffDupSuppress(t *testing.T) {
g := NewSharedCalls()
broadcast := make(chan struct{})
var calls int32
tests := []string{"e", "a", "e", "a", "b", "c", "b", "a", "c", "d", "b", "c", "d"}
var wg sync.WaitGroup
for _, key := range tests {
wg.Add(1)
go func(k string) {
<-broadcast // get all goroutines ready
_, err := g.Do(k, func() (interface{}, error) {
atomic.AddInt32(&calls, 1)
time.Sleep(10 * time.Millisecond)
return nil, nil
})
if err != nil {
t.Errorf("Do error: %v", err)
}
wg.Done()
}(key)
}
time.Sleep(100 * time.Millisecond) // let goroutines above block
close(broadcast)
wg.Wait()
if got := atomic.LoadInt32(&calls); got != 5 { // five letters
t.Errorf("number of calls = %d; want 5", got)
}
}
func TestExclusiveCallDoExDupSuppress(t *testing.T) { func TestExclusiveCallDoExDupSuppress(t *testing.T) {
g := NewSharedCalls() g := NewSharedCalls()
c := make(chan string) c := make(chan string)

View File

@@ -29,5 +29,5 @@ func TestTimeoutLimit(t *testing.T) {
assert.Equal(t, ErrTimeout, limit.Borrow(time.Millisecond*100)) assert.Equal(t, ErrTimeout, limit.Borrow(time.Millisecond*100))
assert.Nil(t, limit.Return()) assert.Nil(t, limit.Return())
assert.Nil(t, limit.Return()) assert.Nil(t, limit.Return())
assert.Equal(t, ErrReturn, limit.Return()) assert.Equal(t, ErrLimitReturn, limit.Return())
} }

View File

@@ -0,0 +1,12 @@
package tracespec
// TracingKey is tracing key for context
var TracingKey = contextKey("X-Trace")
// contextKey a type for context key
type contextKey string
// Printing a context will reveal a fair amount of information about it.
func (c contextKey) String() string {
return "trace/tracespec context key " + string(c)
}

View File

@@ -1,3 +0,0 @@
package tracespec
const TracingKey = "X-Trace"

View File

@@ -57,13 +57,19 @@ And now, lets walk through the complete flow of quickly create a microservice
* install etcd, mysql, redis * install etcd, mysql, redis
* install protoc-gen-go
```shell
go get -u github.com/golang/protobuf/protoc-gen-go
```
* install goctl * install goctl
```shell ```shell
GO111MODULE=on go get -u github.com/tal-tech/go-zero/tools/goctl GO111MODULE=on go get -u github.com/tal-tech/go-zero/tools/goctl
``` ```
* create the working dir bookstore * create the working dir `bookstore` and `bookstore/api`
* in `bookstore` dir, execute `go mod init bookstore` to initialize `go.mod`` * in `bookstore` dir, execute `go mod init bookstore` to initialize `go.mod``
@@ -185,6 +191,8 @@ And now, lets walk through the complete flow of quickly create a microservice
## 6. Write code for add rpc service ## 6. Write code for add rpc service
- under directory `bookstore` create dir `rpc`
* under directory `rpc/add` create `add.proto` file * under directory `rpc/add` create `add.proto` file
```shell ```shell
@@ -347,8 +355,8 @@ you can change the listening port in file `etc/add.yaml`.
```go ```go
type Config struct { type Config struct {
rest.RestConf rest.RestConf
Add rpcx.RpcClientConf // manual code Add zrpc.RpcClientConf // manual code
Check rpcx.RpcClientConf // manual code Check zrpc.RpcClientConf // manual code
} }
``` ```
@@ -364,8 +372,8 @@ you can change the listening port in file `etc/add.yaml`.
func NewServiceContext(c config.Config) *ServiceContext { func NewServiceContext(c config.Config) *ServiceContext {
return &ServiceContext{ return &ServiceContext{
Config: c, Config: c,
Adder: adder.NewAdder(rpcx.MustNewClient(c.Add)), // manual code Adder: adder.NewAdder(zrpc.MustNewClient(c.Add)), // manual code
Checker: checker.NewChecker(rpcx.MustNewClient(c.Check)), // manual code Checker: checker.NewChecker(zrpc.MustNewClient(c.Check)), // manual code
} }
} }
``` ```
@@ -477,7 +485,7 @@ Till now, weve done the modification of API Gateway. All the manually added c
```go ```go
type Config struct { type Config struct {
rpcx.RpcServerConf zrpc.RpcServerConf
DataSource string // manual code DataSource string // manual code
Table string // manual code Table string // manual code
Cache cache.CacheConf // manual code Cache cache.CacheConf // manual code

View File

@@ -57,13 +57,19 @@
* 安装etcd, mysql, redis * 安装etcd, mysql, redis
* 安装`protoc-gen-go`
```shell
go get -u github.com/golang/protobuf/protoc-gen-go
```
* 安装goctl工具 * 安装goctl工具
```shell ```shell
GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/go-zero/tools/goctl GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/go-zero/tools/goctl
``` ```
* 创建工作目录`bookstore` * 创建工作目录 `bookstore` 和 `bookstore/api`
* 在`bookstore`目录下执行`go mod init bookstore`初始化`go.mod` * 在`bookstore`目录下执行`go mod init bookstore`初始化`go.mod`
@@ -185,6 +191,8 @@
## 6. 编写add rpc服务 ## 6. 编写add rpc服务
- 在 `bookstore` 下创建 `rpc` 目录
* 在`rpc/add`目录下编写`add.proto`文件 * 在`rpc/add`目录下编写`add.proto`文件
可以通过命令生成proto文件模板 可以通过命令生成proto文件模板
@@ -349,8 +357,8 @@
```go ```go
type Config struct { type Config struct {
rest.RestConf rest.RestConf
Add rpcx.RpcClientConf // 手动代码 Add zrpc.RpcClientConf // 手动代码
Check rpcx.RpcClientConf // 手动代码 Check zrpc.RpcClientConf // 手动代码
} }
``` ```
@@ -366,8 +374,8 @@
func NewServiceContext(c config.Config) *ServiceContext { func NewServiceContext(c config.Config) *ServiceContext {
return &ServiceContext{ return &ServiceContext{
Config: c, Config: c,
Adder: adder.NewAdder(rpcx.MustNewClient(c.Add)), // 手动代码 Adder: adder.NewAdder(zrpc.MustNewClient(c.Add)), // 手动代码
Checker: checker.NewChecker(rpcx.MustNewClient(c.Check)), // 手动代码 Checker: checker.NewChecker(zrpc.MustNewClient(c.Check)), // 手动代码
} }
} }
``` ```
@@ -477,7 +485,7 @@
```go ```go
type Config struct { type Config struct {
rpcx.RpcServerConf zrpc.RpcServerConf
DataSource string // 手动代码 DataSource string // 手动代码
Table string // 手动代码 Table string // 手动代码
Cache cache.CacheConf // 手动代码 Cache cache.CacheConf // 手动代码
@@ -540,7 +548,7 @@
} }
``` ```
至此代码修改完成,凡手动修改的代码我加了标注 至此代码修改完成,凡手动修改的代码我加了标注
## 11. 完整调用演示 ## 11. 完整调用演示

View File

@@ -221,11 +221,11 @@ OPTIONS:
*rrBalanced does not implement Picker (wrong type for Pick method) *rrBalanced does not implement Picker (wrong type for Pick method)
have Pick(context.Context, balancer.PickInfo) (balancer.SubConn, func(balancer.DoneInfo), error) have Pick(context.Context, balancer.PickInfo) (balancer.SubConn, func(balancer.DoneInfo), error)
want Pick(balancer.PickInfo) (balancer.PickResult, error) want Pick(balancer.PickInfo) (balancer.PickResult, error)
#github.com/tal-tech/go-zero/rpcx/internal/balancer/p2c #github.com/tal-tech/go-zero/zrpc/internal/balancer/p2c
../../../go/pkg/mod/github.com/tal-tech/go-zero@v1.0.12/rpcx/internal/balancer/p2c/p2c.go:41:32: not enough arguments in call to base.NewBalancerBuilder ../../../go/pkg/mod/github.com/tal-tech/go-zero@v1.0.12/zrpc/internal/balancer/p2c/p2c.go:41:32: not enough arguments in call to base.NewBalancerBuilder
have (string, *p2cPickerBuilder) have (string, *p2cPickerBuilder)
want (string, base.PickerBuilder, base.Config) want (string, base.PickerBuilder, base.Config)
../../../go/pkg/mod/github.com/tal-tech/go-zero@v1.0.12/rpcx/internal/balancer/p2c/p2c.go:58:9: cannot use &p2cPicker literal (type *p2cPicker) as type balancer.Picker in return argument: ../../../go/pkg/mod/github.com/tal-tech/go-zero@v1.0.12/zrpc/internal/balancer/p2c/p2c.go:58:9: cannot use &p2cPicker literal (type *p2cPicker) as type balancer.Picker in return argument:
*p2cPicker does not implement balancer.Picker (wrong type for Pick method) *p2cPicker does not implement balancer.Picker (wrong type for Pick method)
have Pick(context.Context, balancer.PickInfo) (balancer.SubConn, func(balancer.DoneInfo), error) have Pick(context.Context, balancer.PickInfo) (balancer.SubConn, func(balancer.DoneInfo), error)
want Pick(balancer.PickInfo) (balancer.PickResult, error) want Pick(balancer.PickInfo) (balancer.PickResult, error)

View File

@@ -95,20 +95,20 @@ service user-api {
) )
@server( @server(
handler: GetUserHandler handler: GetUserHandler
folder: user group: user
) )
get /api/user/:name(getRequest) returns(getResponse) get /api/user/:name(getRequest) returns(getResponse)
@server( @server(
handler: CreateUserHandler handler: CreateUserHandler
folder: user group: user
) )
post /api/users/create(createRequest) post /api/users/create(createRequest)
} }
@server( @server(
jwt: Auth jwt: Auth
folder: profile group: profile
) )
service user-api { service user-api {
@doc(summary: user title) @doc(summary: user title)
@@ -135,7 +135,7 @@ service user-api {
1. info部分描述了api基本信息比如Authapi是哪个用途。 1. info部分描述了api基本信息比如Authapi是哪个用途。
2. type部分type类型声明和golang语法兼容。 2. type部分type类型声明和golang语法兼容。
3. service部分service代表一组服务一个服务可以由多组名称相同的service组成可以针对每一组service配置jwt和auth认证另外通过folder属性可以指定service生成所在子目录。 3. service部分service代表一组服务一个服务可以由多组名称相同的service组成可以针对每一组service配置jwt和auth认证另外通过group属性可以指定service生成所在子目录。
service里面包含api路由比如上面第一组service的第一个路由doc用来描述此路由的用途GetProfileHandler表示处理这个路由的handler service里面包含api路由比如上面第一组service的第一个路由doc用来描述此路由的用途GetProfileHandler表示处理这个路由的handler
`get /api/profile/:name(getRequest) returns(getResponse)` 中get代表api的请求方式get/post/put/delete, `/api/profile/:name` 描述了路由path`:name`通过 `get /api/profile/:name(getRequest) returns(getResponse)` 中get代表api的请求方式get/post/put/delete, `/api/profile/:name` 描述了路由path`:name`通过
请求getRequest里面的属性赋值getResponse为返回的结构体这两个类型都定义在2描述的类型中。 请求getRequest里面的属性赋值getResponse为返回的结构体这两个类型都定义在2描述的类型中。
@@ -239,10 +239,10 @@ src 示例代码如下
``` ```
结构体中不需要提供Id,CreateTime,UpdateTime三个字段会自动生成 结构体中不需要提供Id,CreateTime,UpdateTime三个字段会自动生成
结构体中每个tag有两个可选标签 c 和 o 结构体中每个tag有两个可选标签 c 和 o
c 是字段的注释 c 是字段的注释
o 是字段需要生产的操作函数 可以取得get,find,set 分别表示生成返回单个对象的查询方法,返回多个对象的查询方法,设置该字段方法 o 是字段需要生产的操作函数 可以取得get,find,set 分别表示生成返回单个对象的查询方法,返回多个对象的查询方法,设置该字段方法
生成的目标文件会覆盖该简单go文件 生成的目标文件会覆盖该简单go文件
## goctl rpc生成业务剥离中暂未开放 ## goctl rpc生成业务剥离中暂未开放

Binary file not shown.

Before

Width:  |  Height:  |  Size: 125 KiB

After

Width:  |  Height:  |  Size: 141 KiB

140
doc/jwt.md Normal file
View File

@@ -0,0 +1,140 @@
# 基于go-zero实现JWT认证
关于JWT是什么大家可以看看[官网](https://jwt.io/),一句话介绍下:是可以实现服务器无状态的鉴权认证方案,也是目前最流行的跨域认证解决方案。
要实现JWT认证我们需要分成如下两个步骤
* 客户端获取JWT token。
* 服务器对客户端带来的JWT token认证。
## 1. 客户端获取JWT Token
我们定义一个协议供客户端调用获取JWT token我们新建一个目录jwt然后在目录中执行 `goctl api -o jwt.api`将生成的jwt.api改成如下
````go
type JwtTokenRequest struct {
}
type JwtTokenResponse struct {
AccessToken string `json:"access_token"`
AccessExpire int64 `json:"access_expire"`
RefreshAfter int64 `json:"refresh_after"` // 建议客户端刷新token的绝对时间
}
type GetUserRequest struct {
UserId string `json:"userId"`
}
type GetUserResponse struct {
Name string `json:"name"`
}
service jwt-api {
@server(
handler: JwtHandler
)
post /user/token(JwtTokenRequest) returns (JwtTokenResponse)
}
@server(
jwt: JwtAuth
)
service jwt-api {
@server(
handler: GetUserHandler
)
post /user/info(GetUserRequest) returns (GetUserResponse)
}
````
在服务jwt目录中执行`goctl api go -api jwt.api -dir .`
打开jwtlogic.go文件修改 `func (l *JwtLogic) Jwt(req types.JwtTokenRequest) (*types.JwtTokenResponse, error) {` 方法如下:
```go
func (l *JwtLogic) Jwt(req types.JwtTokenRequest) (*types.JwtTokenResponse, error) {
var accessExpire = l.svcCtx.Config.JwtAuth.AccessExpire
now := time.Now().Unix()
accessToken, err := l.GenToken(now, l.svcCtx.Config.JwtAuth.AccessSecret, nil, accessExpire)
if err != nil {
return nil, err
}
return &types.JwtTokenResponse{
AccessToken: accessToken,
AccessExpire: now + accessExpire,
RefreshAfter: now + accessExpire/2,
}, nil
}
func (l *JwtLogic) GenToken(iat int64, secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
claims := make(jwt.MapClaims)
claims["exp"] = iat + seconds
claims["iat"] = iat
for k, v := range payloads {
claims[k] = v
}
token := jwt.New(jwt.SigningMethodHS256)
token.Claims = claims
return token.SignedString([]byte(secretKey))
}
```
在启动服务之前我们需要修改etc/jwt-api.yaml文件如下
```yaml
Name: jwt-api
Host: 0.0.0.0
Port: 8888
JwtAuth:
AccessSecret: xxxxxxxxxxxxxxxxxxxxxxxxxxxxx
AccessExpire: 604800
```
启动服务器然后测试下获取到的token。
```sh
➜ curl --location --request POST '127.0.0.1:8888/user/token'
{"access_token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDEyNjE0MjksImlhdCI6MTYwMDY1NjYyOX0.6u_hpE_4m5gcI90taJLZtvfekwUmjrbNJ-5saaDGeQc","access_expire":1601261429,"refresh_after":1600959029}
```
## 2. 服务器验证JWT token
1. 在api文件中通过`jwt: JwtAuth`标记的service表示激活了jwt认证。
2. 可以阅读rest/handler/authhandler.go文件了解服务器jwt实现。
3. 修改getuserlogic.go如下
```go
func (l *GetUserLogic) GetUser(req types.GetUserRequest) (*types.GetUserResponse, error) {
return &types.GetUserResponse{Name: "kim"}, nil
}
```
* 我们先不带JWT Authorization header请求头测试下返回http status code是401符合预期。
```sh
➜ curl -w "\nhttp: %{http_code} \n" --location --request POST '127.0.0.1:8888/user/info' \
--header 'Content-Type: application/json' \
--data-raw '{
"userId": "a"
}'
http: 401
```
* 加上Authorization header请求头测试。
```sh
➜ curl -w "\nhttp: %{http_code} \n" --location --request POST '127.0.0.1:8888/user/info' \
--header 'Authorization: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDEyNjE0MjksImlhdCI6MTYwMDY1NjYyOX0.6u_hpE_4m5gcI90taJLZtvfekwUmjrbNJ-5saaDGeQc' \
--header 'Content-Type: application/json' \
--data-raw '{
"userId": "a"
}'
{"name":"kim"}
http: 200
```
综上所述基于go-zero的JWT认证完成在真实生产环境部署时候AccessSecret, AccessExpire, RefreshAfter根据业务场景通过配置文件配置RefreshAfter 是告诉客户端什么时候该刷新JWT token了一般都需要设置过期时间前几天。

View File

@@ -106,10 +106,10 @@ func main() {
// 没有拿到结果则调用makeCall方法去获取资源注意此处仍然是锁住的可以保证只有一个goroutine可以调用makecall // 没有拿到结果则调用makeCall方法去获取资源注意此处仍然是锁住的可以保证只有一个goroutine可以调用makecall
c := g.makeCall(key, fn) c := g.makeCall(key, fn)
// 返回调用结果 // 返回调用结果
return c.val, c.err return c.val, c.err
} }
``` ```
- sharedGroup的DoEx方法 - sharedGroup的DoEx方法
@@ -160,7 +160,7 @@ func main() {
c.val, c.err = fn() c.val, c.err = fn()
return c return c
} }
``` ```
## 最后 ## 最后

View File

@@ -60,13 +60,19 @@ And now, lets walk through the complete flow of quickly create a microservice
* install etcd, mysql, redis * install etcd, mysql, redis
* install protoc-gen-go
```
go get -u github.com/golang/protobuf/protoc-gen-go
```
* install goctl * install goctl
```shell ```shell
GO111MODULE=on go get -u github.com/tal-tech/go-zero/tools/goctl GO111MODULE=on go get -u github.com/tal-tech/go-zero/tools/goctl
``` ```
* create the working dir `shorturl` * create the working dir `shorturl` and `shorturl/api`
* in `shorturl` dir, execute `go mod init shorturl` to initialize `go.mod` * in `shorturl` dir, execute `go mod init shorturl` to initialize `go.mod`
@@ -189,6 +195,8 @@ And now, lets walk through the complete flow of quickly create a microservice
## 6. Write code for transform rpc service ## 6. Write code for transform rpc service
- under directory `shorturl` create dir `rpc`
* under directory `rpc/transform` create `transform.proto` file * under directory `rpc/transform` create `transform.proto` file
```shell ```shell
@@ -284,7 +292,7 @@ And now, lets walk through the complete flow of quickly create a microservice
```go ```go
type Config struct { type Config struct {
rest.RestConf rest.RestConf
Transform rpcx.RpcClientConf // manual code Transform zrpc.RpcClientConf // manual code
} }
``` ```
@@ -299,7 +307,7 @@ And now, lets walk through the complete flow of quickly create a microservice
func NewServiceContext(c config.Config) *ServiceContext { func NewServiceContext(c config.Config) *ServiceContext {
return &ServiceContext{ return &ServiceContext{
Config: c, Config: c,
Transformer: transformer.NewTransformer(rpcx.MustNewClient(c.Transform)), // manual code Transformer: transformer.NewTransformer(zrpc.MustNewClient(c.Transform)), // manual code
} }
} }
``` ```
@@ -409,7 +417,7 @@ Till now, weve done the modification of API Gateway. All the manually added c
```go ```go
type Config struct { type Config struct {
rpcx.RpcServerConf zrpc.RpcServerConf
DataSource string // manual code DataSource string // manual code
Table string // manual code Table string // manual code
Cache cache.CacheConf // manual code Cache cache.CacheConf // manual code

View File

@@ -60,19 +60,31 @@
* 安装etcd, mysql, redis * 安装etcd, mysql, redis
* 安装`protoc-gen-go`
```shell
go get -u github.com/golang/protobuf/protoc-gen-go
```
* 安装goctl工具 * 安装goctl工具
```shell ```shell
GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/go-zero/tools/goctl GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/go-zero/tools/goctl
``` ```
* 创建工作目录`shorturl` * 创建工作目录 `shorturl` 和 `shorturl/api`
* 在`shorturl`目录下执行`go mod init shorturl`初始化`go.mod` * 在`shorturl`目录下执行`go mod init shorturl`初始化`go.mod`
## 5. 编写API Gateway代码 ## 5. 编写API Gateway代码
* 通过goctl生成`api/shorturl.api`并编辑,为了简洁,去除了文件开头的`info`,代码如下 * 在`shorturl/api`目录下通过goctl生成`api/shorturl.api`
```shell
goctl api -o shorturl.api
```
* 编辑`api/shorturl.api`,为了简洁,去除了文件开头的`info`,代码如下:
```go ```go
type ( type (
@@ -183,6 +195,8 @@
## 6. 编写transform rpc服务 ## 6. 编写transform rpc服务
- 在 `shorturl` 目录下创建 `rpc` 目录
* 在`rpc/transform`目录下编写`transform.proto`文件 * 在`rpc/transform`目录下编写`transform.proto`文件
可以通过命令生成proto文件模板 可以通过命令生成proto文件模板
@@ -280,7 +294,7 @@
```go ```go
type Config struct { type Config struct {
rest.RestConf rest.RestConf
Transform rpcx.RpcClientConf // 手动代码 Transform zrpc.RpcClientConf // 手动代码
} }
``` ```
@@ -295,7 +309,7 @@
func NewServiceContext(c config.Config) *ServiceContext { func NewServiceContext(c config.Config) *ServiceContext {
return &ServiceContext{ return &ServiceContext{
Config: c, Config: c,
Transformer: transformer.NewTransformer(rpcx.MustNewClient(c.Transform)), // 手动代码 Transformer: transformer.NewTransformer(zrpc.MustNewClient(c.Transform)), // 手动代码
} }
} }
``` ```
@@ -405,7 +419,7 @@
```go ```go
type Config struct { type Config struct {
rpcx.RpcServerConf zrpc.RpcServerConf
DataSource string // 手动代码 DataSource string // 手动代码
Table string // 手动代码 Table string // 手动代码
Cache cache.CacheConf // 手动代码 Cache cache.CacheConf // 手动代码
@@ -468,7 +482,7 @@
} }
``` ```
至此代码修改完成,凡手动修改的代码我加了标注 至此代码修改完成,凡手动修改的代码我加了标注
## 10. 完整调用演示 ## 10. 完整调用演示

View File

@@ -1,33 +1,33 @@
type ( type (
addReq struct { addReq struct {
book string `form:"book"` book string `form:"book"`
price int64 `form:"price"` price int64 `form:"price"`
} }
addResp struct { addResp struct {
ok bool `json:"ok"` ok bool `json:"ok"`
} }
) )
type ( type (
checkReq struct { checkReq struct {
book string `form:"book"` book string `form:"book"`
} }
checkResp struct { checkResp struct {
found bool `json:"found"` found bool `json:"found"`
price int64 `json:"price"` price int64 `json:"price"`
} }
) )
service bookstore-api { service bookstore-api {
@server( @server(
handler: AddHandler handler: AddHandler
) )
get /add(addReq) returns(addResp) get /add (addReq) returns (addResp)
@server( @server(
handler: CheckHandler handler: CheckHandler
) )
get /check(checkReq) returns(checkResp) get /check (checkReq) returns (checkResp)
} }

View File

@@ -2,11 +2,11 @@ package config
import ( import (
"github.com/tal-tech/go-zero/rest" "github.com/tal-tech/go-zero/rest"
"github.com/tal-tech/go-zero/rpcx" "github.com/tal-tech/go-zero/zrpc"
) )
type Config struct { type Config struct {
rest.RestConf rest.RestConf
Add rpcx.RpcClientConf Add zrpc.RpcClientConf
Check rpcx.RpcClientConf Check zrpc.RpcClientConf
} }

View File

@@ -1,10 +1,11 @@
package handler package handler
import ( import (
"net/http"
"bookstore/api/internal/logic" "bookstore/api/internal/logic"
"bookstore/api/internal/svc" "bookstore/api/internal/svc"
"bookstore/api/internal/types" "bookstore/api/internal/types"
"net/http"
"github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/httpx"
) )

View File

@@ -1,10 +1,11 @@
package handler package handler
import ( import (
"net/http"
"bookstore/api/internal/logic" "bookstore/api/internal/logic"
"bookstore/api/internal/svc" "bookstore/api/internal/svc"
"bookstore/api/internal/types" "bookstore/api/internal/types"
"net/http"
"github.com/tal-tech/go-zero/rest/httpx" "github.com/tal-tech/go-zero/rest/httpx"
) )

Some files were not shown because too many files have changed in this diff Show More