mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-18 12:18:18 +08:00
Compare commits
56 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8478474f7f | ||
|
|
df5ae9507f | ||
|
|
faf4d7e3bb | ||
|
|
f64fe5eb5e | ||
|
|
97d889103a | ||
|
|
9a44310d00 | ||
|
|
06eeef2cf3 | ||
|
|
9adc7d4cb9 | ||
|
|
006f78c3d5 | ||
|
|
64a8e65f4a | ||
|
|
8fd1e76d29 | ||
|
|
0466af5e49 | ||
|
|
7405d7f506 | ||
|
|
afd9ff889e | ||
|
|
7e087de6e6 | ||
|
|
5aded99df5 | ||
|
|
08fb980ad2 | ||
|
|
b94d7aa532 | ||
|
|
ee630b8b57 | ||
|
|
bd82b7d8de | ||
|
|
3d729c77a6 | ||
|
|
e944b59bb3 | ||
|
|
54b5e3f4b2 | ||
|
|
b913229028 | ||
|
|
9963ffb1c1 | ||
|
|
8cb6490724 | ||
|
|
05e37ee20f | ||
|
|
d88da4cc88 | ||
|
|
425430f67c | ||
|
|
4e0d91f6c0 | ||
|
|
8584351b6d | ||
|
|
b19c5223a9 | ||
|
|
99a2d95433 | ||
|
|
9db222bf5b | ||
|
|
ac648d08cb | ||
|
|
6df7fa619c | ||
|
|
bbb4ce586f | ||
|
|
888551627c | ||
|
|
bd623aaac3 | ||
|
|
9e6c2ba2c0 | ||
|
|
c0db8d017d | ||
|
|
52b4f8ca91 | ||
|
|
4884a7b3c6 | ||
|
|
3c6951577d | ||
|
|
fcd15c9b17 | ||
|
|
155e6061cb | ||
|
|
dda7666097 | ||
|
|
c954568b61 | ||
|
|
c2acc43a52 | ||
|
|
1a1a6f5239 | ||
|
|
60c7edf8f8 | ||
|
|
7ad86a52f3 | ||
|
|
1e4e5a02b2 | ||
|
|
39540e21d2 | ||
|
|
b321622c95 | ||
|
|
a25cba5380 |
@@ -122,8 +122,7 @@ func BenchmarkGoogleBreaker(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
type mockedPromise struct {
|
||||
}
|
||||
type mockedPromise struct{}
|
||||
|
||||
func (m *mockedPromise) Accept() {
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func TestAesEcbBase64(t *testing.T) {
|
||||
// more than 32 chars
|
||||
badKey2 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
)
|
||||
var key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
|
||||
key := []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
|
||||
b64Key := base64.StdEncoding.EncodeToString(key)
|
||||
b64Val := base64.StdEncoding.EncodeToString([]byte(val))
|
||||
_, err := EcbEncryptBase64(badKey1, val)
|
||||
|
||||
@@ -139,7 +139,7 @@ func TestRollingWindowBucketTimeBoundary(t *testing.T) {
|
||||
func TestRollingWindowDataRace(t *testing.T) {
|
||||
const size = 3
|
||||
r := NewRollingWindow(size, duration)
|
||||
var stop = make(chan bool)
|
||||
stop := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
|
||||
@@ -25,7 +25,7 @@ func LoadConfig(file string, v interface{}, opts ...Option) error {
|
||||
|
||||
loader, ok := loaders[path.Ext(file)]
|
||||
if !ok {
|
||||
return fmt.Errorf("unrecoginized file type: %s", file)
|
||||
return fmt.Errorf("unrecognized file type: %s", file)
|
||||
}
|
||||
|
||||
var opt options
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
package contextx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ShrinkDeadline returns a new Context with proper deadline base on the given ctx and timeout.
|
||||
// And returns a cancel function as well.
|
||||
func ShrinkDeadline(ctx context.Context, timeout time.Duration) (context.Context, func()) {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
leftTime := time.Until(deadline)
|
||||
if leftTime < timeout {
|
||||
timeout = leftTime
|
||||
}
|
||||
}
|
||||
|
||||
return context.WithDeadline(ctx, time.Now().Add(timeout))
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
package contextx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShrinkDeadlineLess(t *testing.T) {
|
||||
deadline := time.Now().Add(time.Second)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
||||
defer cancel()
|
||||
ctx, cancel = ShrinkDeadline(ctx, time.Minute)
|
||||
defer cancel()
|
||||
dl, ok := ctx.Deadline()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, deadline, dl)
|
||||
}
|
||||
|
||||
func TestShrinkDeadlineMore(t *testing.T) {
|
||||
deadline := time.Now().Add(time.Minute)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
||||
defer cancel()
|
||||
ctx, cancel = ShrinkDeadline(ctx, time.Second)
|
||||
defer cancel()
|
||||
dl, ok := ctx.Deadline()
|
||||
assert.True(t, ok)
|
||||
assert.True(t, dl.Before(deadline))
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func TestChain(t *testing.T) {
|
||||
var errDummy = errors.New("dummy")
|
||||
errDummy := errors.New("dummy")
|
||||
assert.Nil(t, Chain(func() error {
|
||||
return nil
|
||||
}, func() error {
|
||||
|
||||
@@ -15,7 +15,7 @@ type (
|
||||
|
||||
// DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
|
||||
func DoWithRetry(fn func() error, opts ...RetryOption) error {
|
||||
var options = newRetryOptions()
|
||||
options := newRetryOptions()
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ func TestRetry(t *testing.T) {
|
||||
return errors.New("any")
|
||||
}))
|
||||
|
||||
var total = 2 * defaultRetryTimes
|
||||
total := 2 * defaultRetryTimes
|
||||
times = 0
|
||||
assert.Nil(t, DoWithRetry(func() error {
|
||||
times++
|
||||
|
||||
@@ -3,8 +3,6 @@ package fx
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/contextx"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -23,10 +21,11 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
|
||||
for _, opt := range opts {
|
||||
parentCtx = opt()
|
||||
}
|
||||
ctx, cancel := contextx.ShrinkDeadline(parentCtx, timeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error)
|
||||
// create channel with buffer size 1 to avoid goroutine leak
|
||||
done := make(chan error, 1)
|
||||
panicChan := make(chan interface{}, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
@@ -35,7 +34,6 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
|
||||
}
|
||||
}()
|
||||
done <- fn()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
|
||||
@@ -140,7 +140,7 @@ func (h *ConsistentHash) Remove(node interface{}) {
|
||||
index := sort.Search(len(h.keys), func(i int) bool {
|
||||
return h.keys[i] >= hash
|
||||
})
|
||||
if index < len(h.keys) {
|
||||
if index < len(h.keys) && h.keys[index] == hash {
|
||||
h.keys = append(h.keys[:index], h.keys[index+1:]...)
|
||||
}
|
||||
h.removeRingNode(hash, nodeRepr)
|
||||
|
||||
@@ -168,7 +168,7 @@ func (as *adaptiveShedder) maxPass() int64 {
|
||||
}
|
||||
|
||||
func (as *adaptiveShedder) minRt() float64 {
|
||||
var result = defaultMinRt
|
||||
result := defaultMinRt
|
||||
|
||||
as.rtCounter.Reduce(func(b *collection.Bucket) {
|
||||
if b.Count <= 0 {
|
||||
|
||||
@@ -201,7 +201,7 @@ func BenchmarkAdaptiveShedder_Allow(b *testing.B) {
|
||||
logx.Disable()
|
||||
|
||||
bench := func(b *testing.B) {
|
||||
var shedder = NewAdaptiveShedder()
|
||||
shedder := NewAdaptiveShedder()
|
||||
proba := mathx.NewProba()
|
||||
for i := 0; i < 6000; i++ {
|
||||
p, err := shedder.Allow()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package load
|
||||
|
||||
type nopShedder struct {
|
||||
}
|
||||
type nopShedder struct{}
|
||||
|
||||
func newNopShedder() Shedder {
|
||||
return nopShedder{}
|
||||
@@ -11,8 +10,7 @@ func (s nopShedder) Allow() (Promise, error) {
|
||||
return nopPromise{}, nil
|
||||
}
|
||||
|
||||
type nopPromise struct {
|
||||
}
|
||||
type nopPromise struct{}
|
||||
|
||||
func (p nopPromise) Pass() {
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package logx
|
||||
type LogConf struct {
|
||||
ServiceName string `json:",optional"`
|
||||
Mode string `json:",default=console,options=console|file|volume"`
|
||||
TimeFormat string `json:",optional"`
|
||||
Path string `json:",default=logs"`
|
||||
Level string `json:",default=info,options=info|error|severe"`
|
||||
Compress bool `json:",optional"`
|
||||
|
||||
@@ -32,8 +32,6 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
timeFormat = "2006-01-02T15:04:05.000Z07"
|
||||
|
||||
accessFilename = "access.log"
|
||||
errorFilename = "error.log"
|
||||
severeFilename = "severe.log"
|
||||
@@ -64,6 +62,7 @@ var (
|
||||
// ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
|
||||
ErrLogServiceNameNotSet = errors.New("log service name must be set")
|
||||
|
||||
timeFormat = "2006-01-02T15:04:05.000Z07"
|
||||
writeConsole bool
|
||||
logLevel uint32
|
||||
infoLog io.WriteCloser
|
||||
@@ -117,6 +116,10 @@ func MustSetup(c LogConf) {
|
||||
// we need to allow different service frameworks to initialize logx respectively.
|
||||
// the same logic for SetUp
|
||||
func SetUp(c LogConf) error {
|
||||
if len(c.TimeFormat) > 0 {
|
||||
timeFormat = c.TimeFormat
|
||||
}
|
||||
|
||||
switch c.Mode {
|
||||
case consoleMode:
|
||||
setupWithConsole(c)
|
||||
|
||||
@@ -22,8 +22,8 @@ const (
|
||||
dateFormat = "2006-01-02"
|
||||
hoursPerDay = 24
|
||||
bufferSize = 100
|
||||
defaultDirMode = 0755
|
||||
defaultFileMode = 0600
|
||||
defaultDirMode = 0o755
|
||||
defaultFileMode = 0o600
|
||||
)
|
||||
|
||||
// ErrLogFileClosed is an error that indicates the log file is already closed.
|
||||
|
||||
@@ -752,7 +752,7 @@ func TestUnmarshalJsonNumberInt64(t *testing.T) {
|
||||
for i := 0; i <= maxUintBitsToTest; i++ {
|
||||
var intValue int64 = 1 << uint(i)
|
||||
strValue := strconv.FormatInt(intValue, 10)
|
||||
var number = json.Number(strValue)
|
||||
number := json.Number(strValue)
|
||||
m := map[string]interface{}{
|
||||
"ID": number,
|
||||
}
|
||||
@@ -768,7 +768,7 @@ func TestUnmarshalJsonNumberUint64(t *testing.T) {
|
||||
for i := 0; i <= maxUintBitsToTest; i++ {
|
||||
var intValue uint64 = 1 << uint(i)
|
||||
strValue := strconv.FormatUint(intValue, 10)
|
||||
var number = json.Number(strValue)
|
||||
number := json.Number(strValue)
|
||||
m := map[string]interface{}{
|
||||
"ID": number,
|
||||
}
|
||||
@@ -784,7 +784,7 @@ func TestUnmarshalJsonNumberUint64Ptr(t *testing.T) {
|
||||
for i := 0; i <= maxUintBitsToTest; i++ {
|
||||
var intValue uint64 = 1 << uint(i)
|
||||
strValue := strconv.FormatUint(intValue, 10)
|
||||
var number = json.Number(strValue)
|
||||
number := json.Number(strValue)
|
||||
m := map[string]interface{}{
|
||||
"ID": number,
|
||||
}
|
||||
|
||||
@@ -170,6 +170,28 @@ func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isLeftInclude(b byte) (bool, error) {
|
||||
switch b {
|
||||
case '[':
|
||||
return true, nil
|
||||
case '(':
|
||||
return false, nil
|
||||
default:
|
||||
return false, errNumberRange
|
||||
}
|
||||
}
|
||||
|
||||
func isRightInclude(b byte) (bool, error) {
|
||||
switch b {
|
||||
case ']':
|
||||
return true, nil
|
||||
case ')':
|
||||
return false, nil
|
||||
default:
|
||||
return false, errNumberRange
|
||||
}
|
||||
}
|
||||
|
||||
func maybeNewValue(field reflect.StructField, value reflect.Value) {
|
||||
if field.Type.Kind() == reflect.Ptr && value.IsNil() {
|
||||
value.Set(reflect.New(value.Type().Elem()))
|
||||
@@ -211,14 +233,9 @@ func parseNumberRange(str string) (*numberRange, error) {
|
||||
return nil, errNumberRange
|
||||
}
|
||||
|
||||
var leftInclude bool
|
||||
switch str[0] {
|
||||
case '[':
|
||||
leftInclude = true
|
||||
case '(':
|
||||
leftInclude = false
|
||||
default:
|
||||
return nil, errNumberRange
|
||||
leftInclude, err := isLeftInclude(str[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
str = str[1:]
|
||||
@@ -226,14 +243,9 @@ func parseNumberRange(str string) (*numberRange, error) {
|
||||
return nil, errNumberRange
|
||||
}
|
||||
|
||||
var rightInclude bool
|
||||
switch str[len(str)-1] {
|
||||
case ']':
|
||||
rightInclude = true
|
||||
case ')':
|
||||
rightInclude = false
|
||||
default:
|
||||
return nil, errNumberRange
|
||||
rightInclude, err := isRightInclude(str[len(str)-1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
str = str[:len(str)-1]
|
||||
|
||||
@@ -16,8 +16,8 @@ type Foo struct {
|
||||
}
|
||||
|
||||
func TestDeferInt(t *testing.T) {
|
||||
var i = 1
|
||||
var s = "hello"
|
||||
i := 1
|
||||
s := "hello"
|
||||
number := struct {
|
||||
f float64
|
||||
}{
|
||||
|
||||
@@ -36,7 +36,7 @@ func (u Unstable) AroundDuration(base time.Duration) time.Duration {
|
||||
return val
|
||||
}
|
||||
|
||||
// AroundInt returns a randome int64 with given base and deviation.
|
||||
// AroundInt returns a random int64 with given base and deviation.
|
||||
func (u Unstable) AroundInt(base int64) int64 {
|
||||
u.lock.Lock()
|
||||
val := int64((1 + u.deviation - 2*u.deviation*u.r.Float64()) * float64(base))
|
||||
|
||||
@@ -12,7 +12,7 @@ type (
|
||||
// CounterVec interface represents a counter vector.
|
||||
CounterVec interface {
|
||||
// Inc increments labels.
|
||||
Inc(lables ...string)
|
||||
Inc(labels ...string)
|
||||
// Add adds labels with v.
|
||||
Add(v float64, labels ...string)
|
||||
close() bool
|
||||
@@ -50,8 +50,8 @@ func (cv *promCounterVec) Inc(labels ...string) {
|
||||
cv.counter.WithLabelValues(labels...).Inc()
|
||||
}
|
||||
|
||||
func (cv *promCounterVec) Add(v float64, lables ...string) {
|
||||
cv.counter.WithLabelValues(lables...).Add(v)
|
||||
func (cv *promCounterVec) Add(v float64, labels ...string) {
|
||||
cv.counter.WithLabelValues(labels...).Add(v)
|
||||
}
|
||||
|
||||
func (cv *promCounterVec) close() bool {
|
||||
|
||||
@@ -20,7 +20,7 @@ type (
|
||||
close() bool
|
||||
}
|
||||
|
||||
promGuageVec struct {
|
||||
promGaugeVec struct {
|
||||
gauge *prom.GaugeVec
|
||||
}
|
||||
)
|
||||
@@ -39,7 +39,7 @@ func NewGaugeVec(cfg *GaugeVecOpts) GaugeVec {
|
||||
Help: cfg.Help,
|
||||
}, cfg.Labels)
|
||||
prom.MustRegister(vec)
|
||||
gv := &promGuageVec{
|
||||
gv := &promGaugeVec{
|
||||
gauge: vec,
|
||||
}
|
||||
proc.AddShutdownListener(func() {
|
||||
@@ -49,18 +49,18 @@ func NewGaugeVec(cfg *GaugeVecOpts) GaugeVec {
|
||||
return gv
|
||||
}
|
||||
|
||||
func (gv *promGuageVec) Inc(labels ...string) {
|
||||
func (gv *promGaugeVec) Inc(labels ...string) {
|
||||
gv.gauge.WithLabelValues(labels...).Inc()
|
||||
}
|
||||
|
||||
func (gv *promGuageVec) Add(v float64, lables ...string) {
|
||||
gv.gauge.WithLabelValues(lables...).Add(v)
|
||||
func (gv *promGaugeVec) Add(v float64, labels ...string) {
|
||||
gv.gauge.WithLabelValues(labels...).Add(v)
|
||||
}
|
||||
|
||||
func (gv *promGuageVec) Set(v float64, lables ...string) {
|
||||
gv.gauge.WithLabelValues(lables...).Set(v)
|
||||
func (gv *promGaugeVec) Set(v float64, labels ...string) {
|
||||
gv.gauge.WithLabelValues(labels...).Set(v)
|
||||
}
|
||||
|
||||
func (gv *promGuageVec) close() bool {
|
||||
func (gv *promGaugeVec) close() bool {
|
||||
return prom.Unregister(gv.gauge)
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func TestGaugeInc(t *testing.T) {
|
||||
Labels: []string{"path"},
|
||||
})
|
||||
defer gaugeVec.close()
|
||||
gv, _ := gaugeVec.(*promGuageVec)
|
||||
gv, _ := gaugeVec.(*promGaugeVec)
|
||||
gv.Inc("/users")
|
||||
gv.Inc("/users")
|
||||
r := testutil.ToFloat64(gv.gauge)
|
||||
@@ -45,7 +45,7 @@ func TestGaugeAdd(t *testing.T) {
|
||||
Labels: []string{"path"},
|
||||
})
|
||||
defer gaugeVec.close()
|
||||
gv, _ := gaugeVec.(*promGuageVec)
|
||||
gv, _ := gaugeVec.(*promGaugeVec)
|
||||
gv.Add(-10, "/classroom")
|
||||
gv.Add(30, "/classroom")
|
||||
r := testutil.ToFloat64(gv.gauge)
|
||||
@@ -61,7 +61,7 @@ func TestGaugeSet(t *testing.T) {
|
||||
Labels: []string{"path"},
|
||||
})
|
||||
gaugeVec.close()
|
||||
gv, _ := gaugeVec.(*promGuageVec)
|
||||
gv, _ := gaugeVec.(*promGaugeVec)
|
||||
gv.Set(666, "/users")
|
||||
r := testutil.ToFloat64(gv.gauge)
|
||||
assert.Equal(t, float64(666), r)
|
||||
|
||||
@@ -19,7 +19,7 @@ type (
|
||||
// A HistogramVec interface represents a histogram vector.
|
||||
HistogramVec interface {
|
||||
// Observe adds observation v to labels.
|
||||
Observe(v int64, lables ...string)
|
||||
Observe(v int64, labels ...string)
|
||||
close() bool
|
||||
}
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{}
|
||||
}
|
||||
|
||||
// MapReduce maps all elements generated from given generate func,
|
||||
// and reduces the output elemenets with given reducer.
|
||||
// and reduces the output elements with given reducer.
|
||||
func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) {
|
||||
source := buildSource(generate)
|
||||
return MapReduceWithSource(source, mapper, reducer, opts...)
|
||||
|
||||
@@ -7,10 +7,19 @@ import (
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/syncx"
|
||||
"github.com/tal-tech/go-zero/core/threading"
|
||||
)
|
||||
|
||||
var once sync.Once
|
||||
var (
|
||||
once sync.Once
|
||||
enabled syncx.AtomicBool
|
||||
)
|
||||
|
||||
// Enabled returns if prometheus is enabled.
|
||||
func Enabled() bool {
|
||||
return enabled.True()
|
||||
}
|
||||
|
||||
// StartAgent starts a prometheus agent.
|
||||
func StartAgent(c Config) {
|
||||
@@ -19,6 +28,7 @@ func StartAgent(c Config) {
|
||||
return
|
||||
}
|
||||
|
||||
enabled.Set(true)
|
||||
threading.GoSafe(func() {
|
||||
http.Handle(c.Path, promhttp.Handler())
|
||||
addr := fmt.Sprintf("%s:%d", c.Host, c.Port)
|
||||
|
||||
@@ -84,8 +84,7 @@ func (p *mockedProducer) Produce() (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
type mockedListener struct {
|
||||
}
|
||||
type mockedListener struct{}
|
||||
|
||||
func (l *mockedListener) OnPause() {
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ const (
|
||||
DevMode = "dev"
|
||||
// TestMode means test mode.
|
||||
TestMode = "test"
|
||||
// RtMode means regression test mode.
|
||||
RtMode = "rt"
|
||||
// PreMode means pre-release mode.
|
||||
PreMode = "pre"
|
||||
// ProMode means production mode.
|
||||
@@ -56,7 +58,7 @@ func (sc ServiceConf) SetUp() error {
|
||||
|
||||
func (sc ServiceConf) initMode() {
|
||||
switch sc.Mode {
|
||||
case DevMode, TestMode, PreMode:
|
||||
case DevMode, TestMode, RtMode, PreMode:
|
||||
load.Disable()
|
||||
stat.SetReporter(nil)
|
||||
}
|
||||
|
||||
@@ -95,8 +95,7 @@ func WithStarter(start Starter) Service {
|
||||
}
|
||||
|
||||
type (
|
||||
stopper struct {
|
||||
}
|
||||
stopper struct{}
|
||||
|
||||
startOnlyService struct {
|
||||
start func()
|
||||
|
||||
@@ -70,7 +70,7 @@ func TestServiceGroup_WithStart(t *testing.T) {
|
||||
wait.Add(len(multipliers))
|
||||
group := NewServiceGroup()
|
||||
for _, multiplier := range multipliers {
|
||||
var mul = multiplier
|
||||
mul := multiplier
|
||||
group.Add(WithStart(func() {
|
||||
lock.Lock()
|
||||
want *= mul
|
||||
@@ -97,7 +97,7 @@ func TestServiceGroup_WithStarter(t *testing.T) {
|
||||
wait.Add(len(multipliers))
|
||||
group := NewServiceGroup()
|
||||
for _, multiplier := range multipliers {
|
||||
var mul = multiplier
|
||||
mul := multiplier
|
||||
group.Add(WithStarter(mockedStarter{
|
||||
fn: func() {
|
||||
lock.Lock()
|
||||
|
||||
10
core/stores/cache/cachenode.go
vendored
10
core/stores/cache/cachenode.go
vendored
@@ -59,7 +59,7 @@ func NewNode(rds *redis.Redis, barrier syncx.SharedCalls, st *Stat,
|
||||
}
|
||||
}
|
||||
|
||||
// DelCache deletes cached values with keys.
|
||||
// Del deletes cached values with keys.
|
||||
func (c cacheNode) Del(keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
@@ -73,7 +73,7 @@ func (c cacheNode) Del(keys ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCache gets the cache with key and fills into v.
|
||||
// Get gets the cache with key and fills into v.
|
||||
func (c cacheNode) Get(key string, v interface{}) error {
|
||||
err := c.doGetCache(key, v)
|
||||
if err == errPlaceholder {
|
||||
@@ -88,12 +88,12 @@ func (c cacheNode) IsNotFound(err error) bool {
|
||||
return err == c.errNotFound
|
||||
}
|
||||
|
||||
// SetCache sets the cache with key and v, using c.expiry.
|
||||
// Set sets the cache with key and v, using c.expiry.
|
||||
func (c cacheNode) Set(key string, v interface{}) error {
|
||||
return c.SetWithExpire(key, v, c.aroundDuration(c.expiry))
|
||||
}
|
||||
|
||||
// SetCacheWithExpire sets the cache with key and v, using given expire.
|
||||
// SetWithExpire sets the cache with key and v, using given expire.
|
||||
func (c cacheNode) SetWithExpire(key string, v interface{}, expire time.Duration) error {
|
||||
data, err := jsonx.Marshal(v)
|
||||
if err != nil {
|
||||
@@ -108,7 +108,7 @@ func (c cacheNode) String() string {
|
||||
return c.rds.Addr
|
||||
}
|
||||
|
||||
// TakeWithExpire takes the result from cache first, if not found,
|
||||
// Take takes the result from cache first, if not found,
|
||||
// query from DB and set cache using c.expiry, then return the result.
|
||||
func (c cacheNode) Take(v interface{}, key string, query func(v interface{}) error) error {
|
||||
return c.doTake(v, key, query, func(v interface{}) error {
|
||||
|
||||
2
core/stores/cache/cachenode_test.go
vendored
2
core/stores/cache/cachenode_test.go
vendored
@@ -129,7 +129,7 @@ func TestCacheNode_TakeNotFound(t *testing.T) {
|
||||
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
||||
|
||||
store.Del("any")
|
||||
var errDummy = errors.New("dummy")
|
||||
errDummy := errors.New("dummy")
|
||||
err = cn.Take(&str, "any", func(v interface{}) error {
|
||||
return errDummy
|
||||
})
|
||||
|
||||
@@ -12,8 +12,10 @@ import (
|
||||
"github.com/tal-tech/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
var s1, _ = miniredis.Run()
|
||||
var s2, _ = miniredis.Run()
|
||||
var (
|
||||
s1, _ = miniredis.Run()
|
||||
s2, _ = miniredis.Run()
|
||||
)
|
||||
|
||||
func TestRedis_Exists(t *testing.T) {
|
||||
store := clusterStore{dispatcher: hash.NewConsistentHash()}
|
||||
|
||||
@@ -43,11 +43,11 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
func newCollection(collection *mgo.Collection) Collection {
|
||||
func newCollection(collection *mgo.Collection, brk breaker.Breaker) Collection {
|
||||
return &decoratedCollection{
|
||||
name: collection.FullName,
|
||||
collection: collection,
|
||||
brk: breaker.NewBreaker(),
|
||||
brk: brk,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ func TestNewCollection(t *testing.T) {
|
||||
Database: nil,
|
||||
Name: "foo",
|
||||
FullName: "bar",
|
||||
})
|
||||
}, breaker.GetBreaker("localhost"))
|
||||
assert.Equal(t, "bar", col.(*decoratedCollection).name)
|
||||
}
|
||||
|
||||
@@ -279,8 +279,7 @@ func (p *mockPromise) Reject(reason string) {
|
||||
p.reason = reason
|
||||
}
|
||||
|
||||
type dropBreaker struct {
|
||||
}
|
||||
type dropBreaker struct{}
|
||||
|
||||
func (d *dropBreaker) Name() string {
|
||||
return "dummy"
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/tal-tech/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -20,6 +21,7 @@ type (
|
||||
session *concurrentSession
|
||||
db *mgo.Database
|
||||
collection string
|
||||
brk breaker.Breaker
|
||||
opts []Option
|
||||
}
|
||||
)
|
||||
@@ -46,6 +48,7 @@ func NewModel(url, collection string, opts ...Option) (*Model, error) {
|
||||
// If name is empty, the database name provided in the dialed URL is used instead
|
||||
db: session.DB(""),
|
||||
collection: collection,
|
||||
brk: breaker.GetBreaker(url),
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
@@ -66,7 +69,7 @@ func (mm *Model) FindId(id interface{}) (Query, error) {
|
||||
|
||||
// GetCollection returns a Collection with given session.
|
||||
func (mm *Model) GetCollection(session *mgo.Session) Collection {
|
||||
return newCollection(mm.db.C(mm.collection).With(session))
|
||||
return newCollection(mm.db.C(mm.collection).With(session), mm.brk)
|
||||
}
|
||||
|
||||
// Insert inserts docs into mm.
|
||||
|
||||
@@ -17,6 +17,7 @@ type (
|
||||
Host string
|
||||
Type string `json:",default=node,options=node|cluster"`
|
||||
Pass string `json:",optional"`
|
||||
Tls bool `json:",default=false,options=true|false"`
|
||||
}
|
||||
|
||||
// A RedisKeyConf is a redis config with key.
|
||||
@@ -28,7 +29,18 @@ type (
|
||||
|
||||
// NewRedis returns a Redis.
|
||||
func (rc RedisConf) NewRedis() *Redis {
|
||||
return NewRedis(rc.Host, rc.Type, rc.Pass)
|
||||
var opts []Option
|
||||
if rc.Type == ClusterType {
|
||||
opts = append(opts, Cluster())
|
||||
}
|
||||
if len(rc.Pass) > 0 {
|
||||
opts = append(opts, WithPass(rc.Pass))
|
||||
}
|
||||
if rc.Tls {
|
||||
opts = append(opts, WithTLS())
|
||||
}
|
||||
|
||||
return New(rc.Host, opts...)
|
||||
}
|
||||
|
||||
// Validate validates the RedisConf.
|
||||
|
||||
@@ -29,6 +29,9 @@ const (
|
||||
var ErrNilNode = errors.New("nil redis node")
|
||||
|
||||
type (
|
||||
// Option defines the method to customize a Redis.
|
||||
Option func(r *Redis)
|
||||
|
||||
// A Pair is a key/pair set used in redis zset.
|
||||
Pair struct {
|
||||
Key string
|
||||
@@ -40,6 +43,7 @@ type (
|
||||
Addr string
|
||||
Type string
|
||||
Pass string
|
||||
tls bool
|
||||
brk breaker.Breaker
|
||||
}
|
||||
|
||||
@@ -69,19 +73,32 @@ type (
|
||||
FloatCmd = red.FloatCmd
|
||||
)
|
||||
|
||||
// NewRedis returns a Redis.
|
||||
func NewRedis(redisAddr, redisType string, redisPass ...string) *Redis {
|
||||
var pass string
|
||||
for _, v := range redisPass {
|
||||
pass = v
|
||||
}
|
||||
|
||||
return &Redis{
|
||||
Addr: redisAddr,
|
||||
Type: redisType,
|
||||
Pass: pass,
|
||||
// New returns a Redis with given options.
|
||||
func New(addr string, opts ...Option) *Redis {
|
||||
r := &Redis{
|
||||
Addr: addr,
|
||||
Type: NodeType,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// NewRedis returns a Redis.
|
||||
func NewRedis(redisAddr, redisType string, redisPass ...string) *Redis {
|
||||
var opts []Option
|
||||
if redisType == ClusterType {
|
||||
opts = append(opts, Cluster())
|
||||
}
|
||||
for _, v := range redisPass {
|
||||
opts = append(opts, WithPass(v))
|
||||
}
|
||||
|
||||
return New(redisAddr, opts...)
|
||||
}
|
||||
|
||||
// BitCount is redis bitcount command implementation.
|
||||
@@ -250,6 +267,21 @@ func (s *Redis) Eval(script string, keys []string, args ...interface{}) (val int
|
||||
return
|
||||
}
|
||||
|
||||
// EvalSha is the implementation of redis evalsha command.
|
||||
func (s *Redis) EvalSha(sha string, keys []string, args ...interface{}) (val interface{}, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err = conn.EvalSha(sha, keys, args...).Result()
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Exists is the implementation of redis exists command.
|
||||
func (s *Redis) Exists(key string) (val bool, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
@@ -449,14 +481,14 @@ func (s *Redis) GetBit(key string, offset int64) (val int, err error) {
|
||||
}
|
||||
|
||||
// Hdel is the implementation of redis hdel command.
|
||||
func (s *Redis) Hdel(key, field string) (val bool, err error) {
|
||||
func (s *Redis) Hdel(key string, fields ...string) (val bool, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.HDel(key, field).Result()
|
||||
v, err := conn.HDel(key, fields...).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -913,7 +945,6 @@ func (s *Redis) Pipelined(fn func(Pipeliner) error) (err error) {
|
||||
|
||||
_, err = conn.Pipelined(fn)
|
||||
return err
|
||||
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
@@ -1032,6 +1063,16 @@ func (s *Redis) Scard(key string) (val int64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// ScriptLoad is the implementation of redis script load command.
|
||||
func (s *Redis) ScriptLoad(script string) (string, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return conn.ScriptLoad(script).Result()
|
||||
}
|
||||
|
||||
// Set is the implementation of redis set command.
|
||||
func (s *Redis) Set(key string, value string) error {
|
||||
return s.brk.DoWithAcceptable(func() error {
|
||||
@@ -1101,26 +1142,6 @@ func (s *Redis) Sismember(key string, value interface{}) (val bool, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Srem is the implementation of redis srem command.
|
||||
func (s *Redis) Srem(key string, values ...interface{}) (val int, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.SRem(key, values...).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val = int(v)
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Smembers is the implementation of redis smembers command.
|
||||
func (s *Redis) Smembers(key string) (val []string, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
@@ -1166,6 +1187,31 @@ func (s *Redis) Srandmember(key string, count int) (val []string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Srem is the implementation of redis srem command.
|
||||
func (s *Redis) Srem(key string, values ...interface{}) (val int, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.SRem(key, values...).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val = int(v)
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// String returns the string representation of s.
|
||||
func (s *Redis) String() string {
|
||||
return s.Addr
|
||||
}
|
||||
|
||||
// Sunion is the implementation of redis sunion command.
|
||||
func (s *Redis) Sunion(keys ...string) (val []string, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
@@ -1667,18 +1713,25 @@ func (s *Redis) Zunionstore(dest string, store ZStore, keys ...string) (val int6
|
||||
return
|
||||
}
|
||||
|
||||
// String returns the string representation of s.
|
||||
func (s *Redis) String() string {
|
||||
return s.Addr
|
||||
// Cluster customizes the given Redis as a cluster.
|
||||
func Cluster() Option {
|
||||
return func(r *Redis) {
|
||||
r.Type = ClusterType
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Redis) scriptLoad(script string) (string, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
// WithPass customizes the given Redis with given password.
|
||||
func WithPass(pass string) Option {
|
||||
return func(r *Redis) {
|
||||
r.Pass = pass
|
||||
}
|
||||
}
|
||||
|
||||
return conn.ScriptLoad(script).Result()
|
||||
// WithTLS customizes the given Redis with TLS enabled.
|
||||
func WithTLS() Option {
|
||||
return func(r *Redis) {
|
||||
r.tls = true
|
||||
}
|
||||
}
|
||||
|
||||
func acceptable(err error) bool {
|
||||
@@ -1688,9 +1741,9 @@ func acceptable(err error) bool {
|
||||
func getRedis(r *Redis) (RedisNode, error) {
|
||||
switch r.Type {
|
||||
case ClusterType:
|
||||
return getCluster(r.Addr, r.Pass)
|
||||
return getCluster(r)
|
||||
case NodeType:
|
||||
return getClient(r.Addr, r.Pass)
|
||||
return getClient(r)
|
||||
default:
|
||||
return nil, fmt.Errorf("redis type '%s' is not supported", r.Type)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"strconv"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
|
||||
func TestRedis_Exists(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := NewRedis(client.Addr, "").Exists("a")
|
||||
_, err := New(client.Addr, badType()).Exists("a")
|
||||
assert.NotNil(t, err)
|
||||
ok, err := client.Exists("a")
|
||||
assert.Nil(t, err)
|
||||
@@ -26,9 +27,23 @@ func TestRedis_Exists(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisTLS_Exists(t *testing.T) {
|
||||
runOnRedisTLS(t, func(client *Redis) {
|
||||
_, err := New(client.Addr, badType()).Exists("a")
|
||||
assert.NotNil(t, err)
|
||||
ok, err := client.Exists("a")
|
||||
assert.NotNil(t, err)
|
||||
assert.False(t, ok)
|
||||
assert.NotNil(t, client.Set("a", "b"))
|
||||
ok, err = client.Exists("a")
|
||||
assert.NotNil(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Eval(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := NewRedis(client.Addr, "").Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"})
|
||||
_, err := New(client.Addr, badType()).Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"})
|
||||
assert.NotNil(t, err)
|
||||
_, err = client.Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"})
|
||||
assert.Equal(t, Nil, err)
|
||||
@@ -53,7 +68,7 @@ func TestRedis_Hgetall(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
_, err := NewRedis(client.Addr, "").Hgetall("a")
|
||||
_, err := New(client.Addr, badType()).Hgetall("a")
|
||||
assert.NotNil(t, err)
|
||||
vals, err := client.Hgetall("a")
|
||||
assert.Nil(t, err)
|
||||
@@ -66,10 +81,10 @@ func TestRedis_Hgetall(t *testing.T) {
|
||||
|
||||
func TestRedis_Hvals(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.NotNil(t, NewRedis(client.Addr, "").Hset("a", "aa", "aaa"))
|
||||
assert.NotNil(t, New(client.Addr, badType()).Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
_, err := NewRedis(client.Addr, "").Hvals("a")
|
||||
_, err := New(client.Addr, badType()).Hvals("a")
|
||||
assert.NotNil(t, err)
|
||||
vals, err := client.Hvals("a")
|
||||
assert.Nil(t, err)
|
||||
@@ -81,7 +96,7 @@ func TestRedis_Hsetnx(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
_, err := NewRedis(client.Addr, "").Hsetnx("a", "bb", "ccc")
|
||||
_, err := New(client.Addr, badType()).Hsetnx("a", "bb", "ccc")
|
||||
assert.NotNil(t, err)
|
||||
ok, err := client.Hsetnx("a", "bb", "ccc")
|
||||
assert.Nil(t, err)
|
||||
@@ -99,7 +114,7 @@ func TestRedis_HdelHlen(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
_, err := NewRedis(client.Addr, "").Hlen("a")
|
||||
_, err := New(client.Addr, badType()).Hlen("a")
|
||||
assert.NotNil(t, err)
|
||||
num, err := client.Hlen("a")
|
||||
assert.Nil(t, err)
|
||||
@@ -115,7 +130,7 @@ func TestRedis_HdelHlen(t *testing.T) {
|
||||
|
||||
func TestRedis_HIncrBy(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := NewRedis(client.Addr, "").Hincrby("key", "field", 2)
|
||||
_, err := New(client.Addr, badType()).Hincrby("key", "field", 2)
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.Hincrby("key", "field", 2)
|
||||
assert.Nil(t, err)
|
||||
@@ -130,7 +145,7 @@ func TestRedis_Hkeys(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
_, err := NewRedis(client.Addr, "").Hkeys("a")
|
||||
_, err := New(client.Addr, badType()).Hkeys("a")
|
||||
assert.NotNil(t, err)
|
||||
vals, err := client.Hkeys("a")
|
||||
assert.Nil(t, err)
|
||||
@@ -142,7 +157,7 @@ func TestRedis_Hmget(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
_, err := NewRedis(client.Addr, "").Hmget("a", "aa", "bb")
|
||||
_, err := New(client.Addr, badType()).Hmget("a", "aa", "bb")
|
||||
assert.NotNil(t, err)
|
||||
vals, err := client.Hmget("a", "aa", "bb")
|
||||
assert.Nil(t, err)
|
||||
@@ -155,7 +170,7 @@ func TestRedis_Hmget(t *testing.T) {
|
||||
|
||||
func TestRedis_Hmset(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.NotNil(t, NewRedis(client.Addr, "").Hmset("a", nil))
|
||||
assert.NotNil(t, New(client.Addr, badType()).Hmset("a", nil))
|
||||
assert.Nil(t, client.Hmset("a", map[string]string{
|
||||
"aa": "aaa",
|
||||
"bb": "bbb",
|
||||
@@ -179,7 +194,7 @@ func TestRedis_Hscan(t *testing.T) {
|
||||
var cursor uint64 = 0
|
||||
sum := 0
|
||||
for {
|
||||
_, _, err := NewRedis(client.Addr, "").Hscan(key, cursor, "*", 100)
|
||||
_, _, err := New(client.Addr, badType()).Hscan(key, cursor, "*", 100)
|
||||
assert.NotNil(t, err)
|
||||
reMap, next, err := client.Hscan(key, cursor, "*", 100)
|
||||
assert.Nil(t, err)
|
||||
@@ -191,7 +206,7 @@ func TestRedis_Hscan(t *testing.T) {
|
||||
}
|
||||
|
||||
assert.Equal(t, sum, 3100)
|
||||
_, err = NewRedis(client.Addr, "").Del(key)
|
||||
_, err = New(client.Addr, badType()).Del(key)
|
||||
assert.NotNil(t, err)
|
||||
_, err = client.Del(key)
|
||||
assert.Nil(t, err)
|
||||
@@ -200,7 +215,7 @@ func TestRedis_Hscan(t *testing.T) {
|
||||
|
||||
func TestRedis_Incr(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := NewRedis(client.Addr, "").Incr("a")
|
||||
_, err := New(client.Addr, badType()).Incr("a")
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.Incr("a")
|
||||
assert.Nil(t, err)
|
||||
@@ -213,7 +228,7 @@ func TestRedis_Incr(t *testing.T) {
|
||||
|
||||
func TestRedis_IncrBy(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := NewRedis(client.Addr, "").Incrby("a", 2)
|
||||
_, err := New(client.Addr, badType()).Incrby("a", 2)
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.Incrby("a", 2)
|
||||
assert.Nil(t, err)
|
||||
@@ -230,7 +245,7 @@ func TestRedis_Keys(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
err = client.Set("key2", "value2")
|
||||
assert.Nil(t, err)
|
||||
_, err = NewRedis(client.Addr, "").Keys("*")
|
||||
_, err = New(client.Addr, badType()).Keys("*")
|
||||
assert.NotNil(t, err)
|
||||
keys, err := client.Keys("*")
|
||||
assert.Nil(t, err)
|
||||
@@ -241,7 +256,7 @@ func TestRedis_Keys(t *testing.T) {
|
||||
func TestRedis_HyperLogLog(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
client.Ping()
|
||||
r := NewRedis(client.Addr, "")
|
||||
r := New(client.Addr, badType())
|
||||
_, err := r.Pfadd("key1")
|
||||
assert.NotNil(t, err)
|
||||
_, err = r.Pfcount("*")
|
||||
@@ -253,17 +268,17 @@ func TestRedis_HyperLogLog(t *testing.T) {
|
||||
|
||||
func TestRedis_List(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := NewRedis(client.Addr, "").Lpush("key", "value1", "value2")
|
||||
_, err := New(client.Addr, badType()).Lpush("key", "value1", "value2")
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.Lpush("key", "value1", "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, val)
|
||||
_, err = NewRedis(client.Addr, "").Rpush("key", "value3", "value4")
|
||||
_, err = New(client.Addr, badType()).Rpush("key", "value3", "value4")
|
||||
assert.NotNil(t, err)
|
||||
val, err = client.Rpush("key", "value3", "value4")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, val)
|
||||
_, err = NewRedis(client.Addr, "").Llen("key")
|
||||
_, err = New(client.Addr, badType()).Llen("key")
|
||||
assert.NotNil(t, err)
|
||||
val, err = client.Llen("key")
|
||||
assert.Nil(t, err)
|
||||
@@ -271,7 +286,7 @@ func TestRedis_List(t *testing.T) {
|
||||
vals, err := client.Lrange("key", 0, 10)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value2", "value1", "value3", "value4"}, vals)
|
||||
_, err = NewRedis(client.Addr, "").Lpop("key")
|
||||
_, err = New(client.Addr, badType()).Lpop("key")
|
||||
assert.NotNil(t, err)
|
||||
v, err := client.Lpop("key")
|
||||
assert.Nil(t, err)
|
||||
@@ -279,7 +294,7 @@ func TestRedis_List(t *testing.T) {
|
||||
val, err = client.Lpush("key", "value1", "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 5, val)
|
||||
_, err = NewRedis(client.Addr, "").Rpop("key")
|
||||
_, err = New(client.Addr, badType()).Rpop("key")
|
||||
assert.NotNil(t, err)
|
||||
v, err = client.Rpop("key")
|
||||
assert.Nil(t, err)
|
||||
@@ -287,12 +302,12 @@ func TestRedis_List(t *testing.T) {
|
||||
val, err = client.Rpush("key", "value4", "value3", "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 7, val)
|
||||
_, err = NewRedis(client.Addr, "").Lrem("key", 2, "value1")
|
||||
_, err = New(client.Addr, badType()).Lrem("key", 2, "value1")
|
||||
assert.NotNil(t, err)
|
||||
n, err := client.Lrem("key", 2, "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, n)
|
||||
_, err = NewRedis(client.Addr, "").Lrange("key", 0, 10)
|
||||
_, err = New(client.Addr, badType()).Lrange("key", 0, 10)
|
||||
assert.NotNil(t, err)
|
||||
vals, err = client.Lrange("key", 0, 10)
|
||||
assert.Nil(t, err)
|
||||
@@ -312,7 +327,7 @@ func TestRedis_Mget(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
err = client.Set("key2", "value2")
|
||||
assert.Nil(t, err)
|
||||
_, err = NewRedis(client.Addr, "").Mget("key1", "key0", "key2", "key3")
|
||||
_, err = New(client.Addr, badType()).Mget("key1", "key0", "key2", "key3")
|
||||
assert.NotNil(t, err)
|
||||
vals, err := client.Mget("key1", "key0", "key2", "key3")
|
||||
assert.Nil(t, err)
|
||||
@@ -322,7 +337,7 @@ func TestRedis_Mget(t *testing.T) {
|
||||
|
||||
func TestRedis_SetBit(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := NewRedis(client.Addr, "").SetBit("key", 1, 1)
|
||||
err := New(client.Addr, badType()).SetBit("key", 1, 1)
|
||||
assert.NotNil(t, err)
|
||||
err = client.SetBit("key", 1, 1)
|
||||
assert.Nil(t, err)
|
||||
@@ -333,7 +348,7 @@ func TestRedis_GetBit(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.SetBit("key", 2, 1)
|
||||
assert.Nil(t, err)
|
||||
_, err = NewRedis(client.Addr, "").GetBit("key", 2)
|
||||
_, err = New(client.Addr, badType()).GetBit("key", 2)
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.GetBit("key", 2)
|
||||
assert.Nil(t, err)
|
||||
@@ -348,7 +363,7 @@ func TestRedis_BitCount(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
_, err := NewRedis(client.Addr, "").BitCount("key", 0, -1)
|
||||
_, err := New(client.Addr, badType()).BitCount("key", 0, -1)
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.BitCount("key", 0, -1)
|
||||
assert.Nil(t, err)
|
||||
@@ -369,7 +384,6 @@ func TestRedis_BitCount(t *testing.T) {
|
||||
val, err = client.BitCount("key", 2, 2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), val)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
@@ -379,14 +393,14 @@ func TestRedis_BitOpAnd(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
err = client.Set("key2", "1")
|
||||
assert.Nil(t, err)
|
||||
_, err = NewRedis(client.Addr, "").BitOpAnd("destKey", "key1", "key2")
|
||||
_, err = New(client.Addr, badType()).BitOpAnd("destKey", "key1", "key2")
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.BitOpAnd("destKey", "key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), val)
|
||||
valStr, err := client.Get("destKey")
|
||||
assert.Nil(t, err)
|
||||
//destKey binary 110000 ascii 0
|
||||
// destKey binary 110000 ascii 0
|
||||
assert.Equal(t, "0", valStr)
|
||||
})
|
||||
}
|
||||
@@ -395,7 +409,7 @@ func TestRedis_BitOpNot(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.Set("key1", "\u0000")
|
||||
assert.Nil(t, err)
|
||||
_, err = NewRedis(client.Addr, "").BitOpNot("destKey", "key1")
|
||||
_, err = New(client.Addr, badType()).BitOpNot("destKey", "key1")
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.BitOpNot("destKey", "key1")
|
||||
assert.Nil(t, err)
|
||||
@@ -412,7 +426,7 @@ func TestRedis_BitOpOr(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
err = client.Set("key2", "0")
|
||||
assert.Nil(t, err)
|
||||
_, err = NewRedis(client.Addr, "").BitOpOr("destKey", "key1", "key2")
|
||||
_, err = New(client.Addr, badType()).BitOpOr("destKey", "key1", "key2")
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.BitOpOr("destKey", "key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
@@ -429,7 +443,7 @@ func TestRedis_BitOpXor(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
err = client.Set("key2", "\x0f")
|
||||
assert.Nil(t, err)
|
||||
_, err = NewRedis(client.Addr, "").BitOpXor("destKey", "key1", "key2")
|
||||
_, err = New(client.Addr, badType()).BitOpXor("destKey", "key1", "key2")
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.BitOpXor("destKey", "key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
@@ -439,13 +453,14 @@ func TestRedis_BitOpXor(t *testing.T) {
|
||||
assert.Equal(t, "\xf0", valStr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_BitPos(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
//11111111 11110000 00000000
|
||||
// 11111111 11110000 00000000
|
||||
err := client.Set("key", "\xff\xf0\x00")
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = NewRedis(client.Addr, "").BitPos("key", 0, 0, -1)
|
||||
_, err = New(client.Addr, badType()).BitPos("key", 0, 0, -1)
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.BitPos("key", 0, 0, 2)
|
||||
assert.Nil(t, err)
|
||||
@@ -466,13 +481,12 @@ func TestRedis_BitPos(t *testing.T) {
|
||||
val, err = client.BitPos("key", 1, 2, 2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(-1), val)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Persist(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := NewRedis(client.Addr, "").Persist("key")
|
||||
_, err := New(client.Addr, badType()).Persist("key")
|
||||
assert.NotNil(t, err)
|
||||
ok, err := client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
@@ -482,14 +496,14 @@ func TestRedis_Persist(t *testing.T) {
|
||||
ok, err = client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
err = NewRedis(client.Addr, "").Expire("key", 5)
|
||||
err = New(client.Addr, badType()).Expire("key", 5)
|
||||
assert.NotNil(t, err)
|
||||
err = client.Expire("key", 5)
|
||||
assert.Nil(t, err)
|
||||
ok, err = client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
err = NewRedis(client.Addr, "").Expireat("key", time.Now().Unix()+5)
|
||||
err = New(client.Addr, badType()).Expireat("key", time.Now().Unix()+5)
|
||||
assert.NotNil(t, err)
|
||||
err = client.Expireat("key", time.Now().Unix()+5)
|
||||
assert.Nil(t, err)
|
||||
@@ -512,7 +526,7 @@ func TestRedis_Scan(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
err = client.Set("key2", "value2")
|
||||
assert.Nil(t, err)
|
||||
_, _, err = NewRedis(client.Addr, "").Scan(0, "*", 100)
|
||||
_, _, err = New(client.Addr, badType()).Scan(0, "*", 100)
|
||||
assert.NotNil(t, err)
|
||||
keys, _, err := client.Scan(0, "*", 100)
|
||||
assert.Nil(t, err)
|
||||
@@ -534,7 +548,7 @@ func TestRedis_Sscan(t *testing.T) {
|
||||
var cursor uint64 = 0
|
||||
sum := 0
|
||||
for {
|
||||
_, _, err := NewRedis(client.Addr, "").Sscan(key, cursor, "", 100)
|
||||
_, _, err := New(client.Addr, badType()).Sscan(key, cursor, "", 100)
|
||||
assert.NotNil(t, err)
|
||||
keys, next, err := client.Sscan(key, cursor, "", 100)
|
||||
assert.Nil(t, err)
|
||||
@@ -546,7 +560,7 @@ func TestRedis_Sscan(t *testing.T) {
|
||||
}
|
||||
|
||||
assert.Equal(t, sum, 1550)
|
||||
_, err = NewRedis(client.Addr, "").Del(key)
|
||||
_, err = New(client.Addr, badType()).Del(key)
|
||||
assert.NotNil(t, err)
|
||||
_, err = client.Del(key)
|
||||
assert.Nil(t, err)
|
||||
@@ -555,48 +569,48 @@ func TestRedis_Sscan(t *testing.T) {
|
||||
|
||||
func TestRedis_Set(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := NewRedis(client.Addr, "").Sadd("key", 1, 2, 3, 4)
|
||||
_, err := New(client.Addr, badType()).Sadd("key", 1, 2, 3, 4)
|
||||
assert.NotNil(t, err)
|
||||
num, err := client.Sadd("key", 1, 2, 3, 4)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, num)
|
||||
_, err = NewRedis(client.Addr, "").Scard("key")
|
||||
_, err = New(client.Addr, badType()).Scard("key")
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.Scard("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(4), val)
|
||||
_, err = NewRedis(client.Addr, "").Sismember("key", 2)
|
||||
_, err = New(client.Addr, badType()).Sismember("key", 2)
|
||||
assert.NotNil(t, err)
|
||||
ok, err := client.Sismember("key", 2)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
_, err = NewRedis(client.Addr, "").Srem("key", 3, 4)
|
||||
_, err = New(client.Addr, badType()).Srem("key", 3, 4)
|
||||
assert.NotNil(t, err)
|
||||
num, err = client.Srem("key", 3, 4)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
_, err = NewRedis(client.Addr, "").Smembers("key")
|
||||
_, err = New(client.Addr, badType()).Smembers("key")
|
||||
assert.NotNil(t, err)
|
||||
vals, err := client.Smembers("key")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"1", "2"}, vals)
|
||||
_, err = NewRedis(client.Addr, "").Srandmember("key", 1)
|
||||
_, err = New(client.Addr, badType()).Srandmember("key", 1)
|
||||
assert.NotNil(t, err)
|
||||
members, err := client.Srandmember("key", 1)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, members, 1)
|
||||
assert.Contains(t, []string{"1", "2"}, members[0])
|
||||
_, err = NewRedis(client.Addr, "").Spop("key")
|
||||
_, err = New(client.Addr, badType()).Spop("key")
|
||||
assert.NotNil(t, err)
|
||||
member, err := client.Spop("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Contains(t, []string{"1", "2"}, member)
|
||||
_, err = NewRedis(client.Addr, "").Smembers("key")
|
||||
_, err = New(client.Addr, badType()).Smembers("key")
|
||||
assert.NotNil(t, err)
|
||||
vals, err = client.Smembers("key")
|
||||
assert.Nil(t, err)
|
||||
assert.NotContains(t, vals, member)
|
||||
_, err = NewRedis(client.Addr, "").Sadd("key1", 1, 2, 3, 4)
|
||||
_, err = New(client.Addr, badType()).Sadd("key1", 1, 2, 3, 4)
|
||||
assert.NotNil(t, err)
|
||||
num, err = client.Sadd("key1", 1, 2, 3, 4)
|
||||
assert.Nil(t, err)
|
||||
@@ -604,22 +618,22 @@ func TestRedis_Set(t *testing.T) {
|
||||
num, err = client.Sadd("key2", 2, 3, 4, 5)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, num)
|
||||
_, err = NewRedis(client.Addr, "").Sunion("key1", "key2")
|
||||
_, err = New(client.Addr, badType()).Sunion("key1", "key2")
|
||||
assert.NotNil(t, err)
|
||||
vals, err = client.Sunion("key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"1", "2", "3", "4", "5"}, vals)
|
||||
_, err = NewRedis(client.Addr, "").Sunionstore("key3", "key1", "key2")
|
||||
_, err = New(client.Addr, badType()).Sunionstore("key3", "key1", "key2")
|
||||
assert.NotNil(t, err)
|
||||
num, err = client.Sunionstore("key3", "key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 5, num)
|
||||
_, err = NewRedis(client.Addr, "").Sdiff("key1", "key2")
|
||||
_, err = New(client.Addr, badType()).Sdiff("key1", "key2")
|
||||
assert.NotNil(t, err)
|
||||
vals, err = client.Sdiff("key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"1"}, vals)
|
||||
_, err = NewRedis(client.Addr, "").Sdiffstore("key4", "key1", "key2")
|
||||
_, err = New(client.Addr, badType()).Sdiffstore("key4", "key1", "key2")
|
||||
assert.NotNil(t, err)
|
||||
num, err = client.Sdiffstore("key4", "key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
@@ -629,11 +643,11 @@ func TestRedis_Set(t *testing.T) {
|
||||
|
||||
func TestRedis_SetGetDel(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := NewRedis(client.Addr, "").Set("hello", "world")
|
||||
err := New(client.Addr, badType()).Set("hello", "world")
|
||||
assert.NotNil(t, err)
|
||||
err = client.Set("hello", "world")
|
||||
assert.Nil(t, err)
|
||||
_, err = NewRedis(client.Addr, "").Get("hello")
|
||||
_, err = New(client.Addr, badType()).Get("hello")
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
@@ -646,11 +660,11 @@ func TestRedis_SetGetDel(t *testing.T) {
|
||||
|
||||
func TestRedis_SetExNx(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := NewRedis(client.Addr, "").Setex("hello", "world", 5)
|
||||
err := New(client.Addr, badType()).Setex("hello", "world", 5)
|
||||
assert.NotNil(t, err)
|
||||
err = client.Setex("hello", "world", 5)
|
||||
assert.Nil(t, err)
|
||||
_, err = NewRedis(client.Addr, "").Setnx("hello", "newworld")
|
||||
_, err = New(client.Addr, badType()).Setnx("hello", "newworld")
|
||||
assert.NotNil(t, err)
|
||||
ok, err := client.Setnx("hello", "newworld")
|
||||
assert.Nil(t, err)
|
||||
@@ -667,7 +681,7 @@ func TestRedis_SetExNx(t *testing.T) {
|
||||
ttl, err := client.Ttl("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ttl > 0)
|
||||
_, err = NewRedis(client.Addr, "").SetnxEx("newhello", "newworld", 5)
|
||||
_, err = New(client.Addr, badType()).SetnxEx("newhello", "newworld", 5)
|
||||
assert.NotNil(t, err)
|
||||
ok, err = client.SetnxEx("newhello", "newworld", 5)
|
||||
assert.Nil(t, err)
|
||||
@@ -688,17 +702,17 @@ func TestRedis_SetGetDelHashField(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.Hset("key", "field", "value")
|
||||
assert.Nil(t, err)
|
||||
_, err = NewRedis(client.Addr, "").Hget("key", "field")
|
||||
_, err = New(client.Addr, badType()).Hget("key", "field")
|
||||
assert.NotNil(t, err)
|
||||
val, err := client.Hget("key", "field")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", val)
|
||||
_, err = NewRedis(client.Addr, "").Hexists("key", "field")
|
||||
_, err = New(client.Addr, badType()).Hexists("key", "field")
|
||||
assert.NotNil(t, err)
|
||||
ok, err := client.Hexists("key", "field")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
_, err = NewRedis(client.Addr, "").Hdel("key", "field")
|
||||
_, err = New(client.Addr, badType()).Hdel("key", "field")
|
||||
assert.NotNil(t, err)
|
||||
ret, err := client.Hdel("key", "field")
|
||||
assert.Nil(t, err)
|
||||
@@ -720,17 +734,17 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
val, err := client.Zscore("key", "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(2), val)
|
||||
_, err = NewRedis(client.Addr, "").Zincrby("key", 3, "value1")
|
||||
_, err = New(client.Addr, badType()).Zincrby("key", 3, "value1")
|
||||
assert.NotNil(t, err)
|
||||
val, err = client.Zincrby("key", 3, "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(5), val)
|
||||
_, err = NewRedis(client.Addr, "").Zscore("key", "value1")
|
||||
_, err = New(client.Addr, badType()).Zscore("key", "value1")
|
||||
assert.NotNil(t, err)
|
||||
val, err = client.Zscore("key", "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(5), val)
|
||||
_, err = NewRedis(client.Addr, "").Zadds("key")
|
||||
_, err = New(client.Addr, badType()).Zadds("key")
|
||||
assert.NotNil(t, err)
|
||||
val, err = client.Zadds("key", Pair{
|
||||
Key: "value2",
|
||||
@@ -741,7 +755,7 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(2), val)
|
||||
_, err = NewRedis(client.Addr, "").ZRevRangeWithScores("key", 1, 3)
|
||||
_, err = New(client.Addr, badType()).ZRevRangeWithScores("key", 1, 3)
|
||||
assert.NotNil(t, err)
|
||||
pairs, err := client.ZRevRangeWithScores("key", 1, 3)
|
||||
assert.Nil(t, err)
|
||||
@@ -761,11 +775,11 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
rank, err = client.Zrevrank("key", "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(2), rank)
|
||||
_, err = NewRedis(client.Addr, "").Zrank("key", "value4")
|
||||
_, err = New(client.Addr, badType()).Zrank("key", "value4")
|
||||
assert.NotNil(t, err)
|
||||
_, err = client.Zrank("key", "value4")
|
||||
assert.Equal(t, Nil, err)
|
||||
_, err = NewRedis(client.Addr, "").Zrem("key", "value2", "value3")
|
||||
_, err = New(client.Addr, badType()).Zrem("key", "value2", "value3")
|
||||
assert.NotNil(t, err)
|
||||
num, err := client.Zrem("key", "value2", "value3")
|
||||
assert.Nil(t, err)
|
||||
@@ -779,7 +793,7 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
ok, err = client.Zadd("key", 8, "value4")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
_, err = NewRedis(client.Addr, "").Zremrangebyscore("key", 6, 7)
|
||||
_, err = New(client.Addr, badType()).Zremrangebyscore("key", 6, 7)
|
||||
assert.NotNil(t, err)
|
||||
num, err = client.Zremrangebyscore("key", 6, 7)
|
||||
assert.Nil(t, err)
|
||||
@@ -787,37 +801,37 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
ok, err = client.Zadd("key", 6, "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
_, err = NewRedis(client.Addr, "").Zadd("key", 7, "value3")
|
||||
_, err = New(client.Addr, badType()).Zadd("key", 7, "value3")
|
||||
assert.NotNil(t, err)
|
||||
ok, err = client.Zadd("key", 7, "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
_, err = NewRedis(client.Addr, "").Zcount("key", 6, 7)
|
||||
_, err = New(client.Addr, badType()).Zcount("key", 6, 7)
|
||||
assert.NotNil(t, err)
|
||||
num, err = client.Zcount("key", 6, 7)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
_, err = NewRedis(client.Addr, "").Zremrangebyrank("key", 1, 2)
|
||||
_, err = New(client.Addr, badType()).Zremrangebyrank("key", 1, 2)
|
||||
assert.NotNil(t, err)
|
||||
num, err = client.Zremrangebyrank("key", 1, 2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
_, err = NewRedis(client.Addr, "").Zcard("key")
|
||||
_, err = New(client.Addr, badType()).Zcard("key")
|
||||
assert.NotNil(t, err)
|
||||
card, err := client.Zcard("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, card)
|
||||
_, err = NewRedis(client.Addr, "").Zrange("key", 0, -1)
|
||||
_, err = New(client.Addr, badType()).Zrange("key", 0, -1)
|
||||
assert.NotNil(t, err)
|
||||
vals, err := client.Zrange("key", 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value1", "value4"}, vals)
|
||||
_, err = NewRedis(client.Addr, "").Zrevrange("key", 0, -1)
|
||||
_, err = New(client.Addr, badType()).Zrevrange("key", 0, -1)
|
||||
assert.NotNil(t, err)
|
||||
vals, err = client.Zrevrange("key", 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value4", "value1"}, vals)
|
||||
_, err = NewRedis(client.Addr, "").ZrangeWithScores("key", 0, -1)
|
||||
_, err = New(client.Addr, badType()).ZrangeWithScores("key", 0, -1)
|
||||
assert.NotNil(t, err)
|
||||
pairs, err = client.ZrangeWithScores("key", 0, -1)
|
||||
assert.Nil(t, err)
|
||||
@@ -831,7 +845,7 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
Score: 8,
|
||||
},
|
||||
}, pairs)
|
||||
_, err = NewRedis(client.Addr, "").ZrangebyscoreWithScores("key", 5, 8)
|
||||
_, err = New(client.Addr, badType()).ZrangebyscoreWithScores("key", 5, 8)
|
||||
assert.NotNil(t, err)
|
||||
pairs, err = client.ZrangebyscoreWithScores("key", 5, 8)
|
||||
assert.Nil(t, err)
|
||||
@@ -845,7 +859,7 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
Score: 8,
|
||||
},
|
||||
}, pairs)
|
||||
_, err = NewRedis(client.Addr, "").ZrangebyscoreWithScoresAndLimit(
|
||||
_, err = New(client.Addr, badType()).ZrangebyscoreWithScoresAndLimit(
|
||||
"key", 5, 8, 1, 1)
|
||||
assert.NotNil(t, err)
|
||||
pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
|
||||
@@ -859,7 +873,7 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 0)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(pairs))
|
||||
_, err = NewRedis(client.Addr, "").ZrevrangebyscoreWithScores("key", 5, 8)
|
||||
_, err = New(client.Addr, badType()).ZrevrangebyscoreWithScores("key", 5, 8)
|
||||
assert.NotNil(t, err)
|
||||
pairs, err = client.ZrevrangebyscoreWithScores("key", 5, 8)
|
||||
assert.Nil(t, err)
|
||||
@@ -873,7 +887,7 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
Score: 5,
|
||||
},
|
||||
}, pairs)
|
||||
_, err = NewRedis(client.Addr, "").ZrevrangebyscoreWithScoresAndLimit(
|
||||
_, err = New(client.Addr, badType()).ZrevrangebyscoreWithScoresAndLimit(
|
||||
"key", 5, 8, 1, 1)
|
||||
assert.NotNil(t, err)
|
||||
pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
|
||||
@@ -887,7 +901,7 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 0)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(pairs))
|
||||
_, err = NewRedis(client.Addr, "").Zrevrank("key", "value")
|
||||
_, err = New(client.Addr, badType()).Zrevrank("key", "value")
|
||||
assert.NotNil(t, err)
|
||||
client.Zadd("second", 2, "aa")
|
||||
client.Zadd("third", 3, "bbb")
|
||||
@@ -897,6 +911,8 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
}, "second", "third")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(2), val)
|
||||
_, err = New(client.Addr, badType()).Zunionstore("union", ZStore{})
|
||||
assert.NotNil(t, err)
|
||||
vals, err = client.Zrange("union", 0, 10000)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"aa", "bbb"}, vals)
|
||||
@@ -908,7 +924,7 @@ func TestRedis_SortedSet(t *testing.T) {
|
||||
|
||||
func TestRedis_Pipelined(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.NotNil(t, NewRedis(client.Addr, "").Pipelined(func(pipeliner Pipeliner) error {
|
||||
assert.NotNil(t, New(client.Addr, badType()).Pipelined(func(pipeliner Pipeliner) error {
|
||||
return nil
|
||||
}))
|
||||
err := client.Pipelined(
|
||||
@@ -920,7 +936,7 @@ func TestRedis_Pipelined(t *testing.T) {
|
||||
},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
_, err = NewRedis(client.Addr, "").Ttl("pipelined_counter")
|
||||
_, err = New(client.Addr, badType()).Ttl("pipelined_counter")
|
||||
assert.NotNil(t, err)
|
||||
ttl, err := client.Ttl("pipelined_counter")
|
||||
assert.Nil(t, err)
|
||||
@@ -940,20 +956,31 @@ func TestRedisString(t *testing.T) {
|
||||
_, err := getRedis(NewRedis(client.Addr, ClusterType))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, client.Addr, client.String())
|
||||
assert.NotNil(t, NewRedis(client.Addr, "").Ping())
|
||||
assert.NotNil(t, New(client.Addr, badType()).Ping())
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisScriptLoad(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
client.Ping()
|
||||
_, err := NewRedis(client.Addr, "").scriptLoad("foo")
|
||||
_, err := New(client.Addr, badType()).ScriptLoad("foo")
|
||||
assert.NotNil(t, err)
|
||||
_, err = client.scriptLoad("foo")
|
||||
_, err = client.ScriptLoad("foo")
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisEvalSha(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
client.Ping()
|
||||
scriptHash, err := client.ScriptLoad(`return redis.call("EXISTS", KEYS[1])`)
|
||||
assert.Nil(t, err)
|
||||
result, err := client.EvalSha(scriptHash, []string{"key1"})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisToPairs(t *testing.T) {
|
||||
pairs := toPairs([]red.Z{
|
||||
{
|
||||
@@ -1007,7 +1034,7 @@ func TestRedisBlpopEx(t *testing.T) {
|
||||
func TestRedisGeo(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
client.Ping()
|
||||
var geoLocation = []*GeoLocation{{Longitude: 13.361389, Latitude: 38.115556, Name: "Palermo"}, {Longitude: 15.087269, Latitude: 37.502669, Name: "Catania"}}
|
||||
geoLocation := []*GeoLocation{{Longitude: 13.361389, Latitude: 38.115556, Name: "Palermo"}, {Longitude: 15.087269, Latitude: 37.502669, Name: "Catania"}}
|
||||
v, err := client.GeoAdd("sicily", geoLocation...)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(2), v)
|
||||
@@ -1025,7 +1052,7 @@ func TestRedisGeo(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(v4[0].Dist), int64(190))
|
||||
assert.Equal(t, int64(v4[1].Dist), int64(56))
|
||||
var geoLocation2 = []*GeoLocation{{Longitude: 13.583333, Latitude: 37.316667, Name: "Agrigento"}}
|
||||
geoLocation2 := []*GeoLocation{{Longitude: 13.583333, Latitude: 37.316667, Name: "Agrigento"}}
|
||||
v5, err := client.GeoAdd("sicily", geoLocation2...)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), v5)
|
||||
@@ -1036,6 +1063,13 @@ func TestRedisGeo(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_WithPass(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := NewRedis(client.Addr, NodeType, "any").Ping()
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func runOnRedis(t *testing.T, fn func(client *Redis)) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
@@ -1051,10 +1085,35 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) {
|
||||
client.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
fn(NewRedis(s.Addr(), NodeType))
|
||||
}
|
||||
|
||||
func runOnRedisTLS(t *testing.T, fn func(client *Redis)) {
|
||||
s, err := miniredis.RunTLS(&tls.Config{
|
||||
Certificates: make([]tls.Certificate, 1),
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
defer func() {
|
||||
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
|
||||
return nil, errors.New("should already exist")
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
}()
|
||||
fn(New(s.Addr(), WithTLS()))
|
||||
}
|
||||
|
||||
func badType() Option {
|
||||
return func(r *Redis) {
|
||||
r.Type = "bad"
|
||||
}
|
||||
}
|
||||
|
||||
type mockedNode struct {
|
||||
RedisNode
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
|
||||
red "github.com/go-redis/redis"
|
||||
@@ -15,14 +16,21 @@ const (
|
||||
|
||||
var clientManager = syncx.NewResourceManager()
|
||||
|
||||
func getClient(server, pass string) (*red.Client, error) {
|
||||
val, err := clientManager.GetResource(server, func() (io.Closer, error) {
|
||||
func getClient(r *Redis) (*red.Client, error) {
|
||||
val, err := clientManager.GetResource(r.Addr, func() (io.Closer, error) {
|
||||
var tlsConfig *tls.Config
|
||||
if r.tls {
|
||||
tlsConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
}
|
||||
store := red.NewClient(&red.Options{
|
||||
Addr: server,
|
||||
Password: pass,
|
||||
Addr: r.Addr,
|
||||
Password: r.Pass,
|
||||
DB: defaultDatabase,
|
||||
MaxRetries: maxRetries,
|
||||
MinIdleConns: idleConns,
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
store.WrapProcess(process)
|
||||
return store, nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
|
||||
red "github.com/go-redis/redis"
|
||||
@@ -9,13 +10,20 @@ import (
|
||||
|
||||
var clusterManager = syncx.NewResourceManager()
|
||||
|
||||
func getCluster(server, pass string) (*red.ClusterClient, error) {
|
||||
val, err := clusterManager.GetResource(server, func() (io.Closer, error) {
|
||||
func getCluster(r *Redis) (*red.ClusterClient, error) {
|
||||
val, err := clusterManager.GetResource(r.Addr, func() (io.Closer, error) {
|
||||
var tlsConfig *tls.Config
|
||||
if r.tls {
|
||||
tlsConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
}
|
||||
store := red.NewClusterClient(&red.ClusterOptions{
|
||||
Addrs: []string{server},
|
||||
Password: pass,
|
||||
Addrs: []string{r.Addr},
|
||||
Password: r.Pass,
|
||||
MaxRetries: maxRetries,
|
||||
MinIdleConns: idleConns,
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
store.WrapProcess(process)
|
||||
|
||||
|
||||
@@ -53,7 +53,8 @@ func NewRedisLock(store *Redis, key string) *RedisLock {
|
||||
func (rl *RedisLock) Acquire() (bool, error) {
|
||||
seconds := atomic.LoadUint32(&rl.seconds)
|
||||
resp, err := rl.store.Eval(lockCommand, []string{rl.key}, []string{
|
||||
rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance)})
|
||||
rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance),
|
||||
})
|
||||
if err == red.Nil {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
|
||||
@@ -24,6 +24,7 @@ type (
|
||||
ResultHandler func(sql.Result, error)
|
||||
|
||||
// A BulkInserter is used to batch insert records.
|
||||
// Postgresql is not supported yet, because of the sql is formated with symbol `$`.
|
||||
BulkInserter struct {
|
||||
executor *executors.PeriodicalExecutor
|
||||
inserter *dbInserter
|
||||
|
||||
@@ -206,7 +206,7 @@ func TestUnmarshalRowString(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStruct(t *testing.T) {
|
||||
var value = new(struct {
|
||||
value := new(struct {
|
||||
Name string
|
||||
Age int
|
||||
})
|
||||
@@ -224,7 +224,7 @@ func TestUnmarshalRowStruct(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
var value = new(struct {
|
||||
value := new(struct {
|
||||
Age int `db:"age"`
|
||||
Name string `db:"name"`
|
||||
})
|
||||
@@ -242,7 +242,7 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
||||
var value = new(struct {
|
||||
value := new(struct {
|
||||
Age *int `db:"age"`
|
||||
Name string `db:"name"`
|
||||
})
|
||||
@@ -259,7 +259,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsBool(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []bool{true, false}
|
||||
expect := []bool{true, false}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -273,7 +273,7 @@ func TestUnmarshalRowsBool(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsInt(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int{2, 3}
|
||||
expect := []int{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -287,7 +287,7 @@ func TestUnmarshalRowsInt(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsInt8(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int8{2, 3}
|
||||
expect := []int8{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -301,7 +301,7 @@ func TestUnmarshalRowsInt8(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsInt16(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int16{2, 3}
|
||||
expect := []int16{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -315,7 +315,7 @@ func TestUnmarshalRowsInt16(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsInt32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int32{2, 3}
|
||||
expect := []int32{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -329,7 +329,7 @@ func TestUnmarshalRowsInt32(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsInt64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int64{2, 3}
|
||||
expect := []int64{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -343,7 +343,7 @@ func TestUnmarshalRowsInt64(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsUint(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint{2, 3}
|
||||
expect := []uint{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -357,7 +357,7 @@ func TestUnmarshalRowsUint(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsUint8(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint8{2, 3}
|
||||
expect := []uint8{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -371,7 +371,7 @@ func TestUnmarshalRowsUint8(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsUint16(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint16{2, 3}
|
||||
expect := []uint16{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -385,7 +385,7 @@ func TestUnmarshalRowsUint16(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsUint32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint32{2, 3}
|
||||
expect := []uint32{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -399,7 +399,7 @@ func TestUnmarshalRowsUint32(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsUint64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint64{2, 3}
|
||||
expect := []uint64{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -413,7 +413,7 @@ func TestUnmarshalRowsUint64(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsFloat32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []float32{2, 3}
|
||||
expect := []float32{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -427,7 +427,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsFloat64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []float64{2, 3}
|
||||
expect := []float64{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -441,7 +441,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) {
|
||||
|
||||
func TestUnmarshalRowsString(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []string{"hello", "world"}
|
||||
expect := []string{"hello", "world"}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -457,7 +457,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) {
|
||||
yes := true
|
||||
no := false
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*bool{&yes, &no}
|
||||
expect := []*bool{&yes, &no}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -473,7 +473,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) {
|
||||
two := 2
|
||||
three := 3
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int{&two, &three}
|
||||
expect := []*int{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -489,7 +489,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) {
|
||||
two := int8(2)
|
||||
three := int8(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int8{&two, &three}
|
||||
expect := []*int8{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -505,7 +505,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) {
|
||||
two := int16(2)
|
||||
three := int16(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int16{&two, &three}
|
||||
expect := []*int16{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -521,7 +521,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) {
|
||||
two := int32(2)
|
||||
three := int32(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int32{&two, &three}
|
||||
expect := []*int32{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -537,7 +537,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) {
|
||||
two := int64(2)
|
||||
three := int64(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int64{&two, &three}
|
||||
expect := []*int64{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -553,7 +553,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) {
|
||||
two := uint(2)
|
||||
three := uint(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint{&two, &three}
|
||||
expect := []*uint{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -569,7 +569,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) {
|
||||
two := uint8(2)
|
||||
three := uint8(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint8{&two, &three}
|
||||
expect := []*uint8{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -585,7 +585,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) {
|
||||
two := uint16(2)
|
||||
three := uint16(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint16{&two, &three}
|
||||
expect := []*uint16{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -601,7 +601,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) {
|
||||
two := uint32(2)
|
||||
three := uint32(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint32{&two, &three}
|
||||
expect := []*uint32{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -617,7 +617,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) {
|
||||
two := uint64(2)
|
||||
three := uint64(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint64{&two, &three}
|
||||
expect := []*uint64{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -633,7 +633,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
|
||||
two := float32(2)
|
||||
three := float32(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*float32{&two, &three}
|
||||
expect := []*float32{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -649,7 +649,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
|
||||
two := float64(2)
|
||||
three := float64(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*float64{&two, &three}
|
||||
expect := []*float64{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -665,7 +665,7 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
|
||||
hello := "hello"
|
||||
world := "world"
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*string{&hello, &world}
|
||||
expect := []*string{&hello, &world}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
@@ -678,7 +678,7 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStruct(t *testing.T) {
|
||||
var expect = []struct {
|
||||
expect := []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
@@ -711,7 +711,7 @@ func TestUnmarshalRowsStruct(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
|
||||
var expect = []struct {
|
||||
expect := []struct {
|
||||
Name string
|
||||
NullString sql.NullString
|
||||
}{
|
||||
@@ -752,7 +752,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithTags(t *testing.T) {
|
||||
var expect = []struct {
|
||||
expect := []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
@@ -789,7 +789,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
|
||||
Value int64 `db:"value"`
|
||||
}
|
||||
|
||||
var expect = []struct {
|
||||
expect := []struct {
|
||||
Name string
|
||||
Age int64
|
||||
Value int64
|
||||
@@ -831,7 +831,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
|
||||
Value int64 `db:"value"`
|
||||
}
|
||||
|
||||
var expect = []struct {
|
||||
expect := []struct {
|
||||
Name string
|
||||
Age int64
|
||||
Value int64
|
||||
@@ -869,7 +869,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructPtr(t *testing.T) {
|
||||
var expect = []*struct {
|
||||
expect := []*struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
@@ -902,7 +902,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
|
||||
var expect = []*struct {
|
||||
expect := []*struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
@@ -935,7 +935,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
|
||||
var expect = []*struct {
|
||||
expect := []*struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
|
||||
@@ -12,14 +12,10 @@ import (
|
||||
const slowThreshold = time.Millisecond * 500
|
||||
|
||||
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
stmt := formatForPrint(q, args)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
|
||||
} else {
|
||||
@@ -33,10 +29,10 @@ func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
}
|
||||
|
||||
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
||||
stmt := fmt.Sprint(args...)
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(args...)
|
||||
duration := timex.Since(startTime)
|
||||
stmt := fmt.Sprint(args...)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
|
||||
} else {
|
||||
@@ -50,14 +46,10 @@ func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
||||
}
|
||||
|
||||
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
stmt := fmt.Sprint(args...)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
|
||||
} else {
|
||||
|
||||
@@ -16,7 +16,6 @@ func TestStmt_exec(t *testing.T) {
|
||||
name string
|
||||
args []interface{}
|
||||
delay bool
|
||||
formatError bool
|
||||
hasError bool
|
||||
err error
|
||||
lastInsertId int64
|
||||
@@ -28,12 +27,6 @@ func TestStmt_exec(t *testing.T) {
|
||||
lastInsertId: 1,
|
||||
rowsAffected: 2,
|
||||
},
|
||||
{
|
||||
name: "wrong format",
|
||||
args: []interface{}{1, 2},
|
||||
formatError: true,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "exec error",
|
||||
args: []interface{}{1},
|
||||
@@ -70,18 +63,13 @@ func TestStmt_exec(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
for i, fn := range fns {
|
||||
i := i
|
||||
for _, fn := range fns {
|
||||
fn := fn
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res, err := fn(test.args...)
|
||||
if i == 0 && test.formatError {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
if !test.formatError && test.hasError {
|
||||
if test.hasError {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
@@ -100,23 +88,16 @@ func TestStmt_exec(t *testing.T) {
|
||||
|
||||
func TestStmt_query(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []interface{}
|
||||
delay bool
|
||||
formatError bool
|
||||
hasError bool
|
||||
err error
|
||||
name string
|
||||
args []interface{}
|
||||
delay bool
|
||||
hasError bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
args: []interface{}{1},
|
||||
},
|
||||
{
|
||||
name: "wrong format",
|
||||
args: []interface{}{1, 2},
|
||||
formatError: true,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "query error",
|
||||
args: []interface{}{1},
|
||||
@@ -151,18 +132,13 @@ func TestStmt_query(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
for i, fn := range fns {
|
||||
i := i
|
||||
for _, fn := range fns {
|
||||
fn := fn
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := fn(test.args...)
|
||||
if i == 0 && test.formatError {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
if !test.formatError && test.hasError {
|
||||
if test.hasError {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -45,6 +45,24 @@ func escape(input string) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func formatForPrint(query string, args ...interface{}) string {
|
||||
if len(args) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
var vals []string
|
||||
for _, arg := range args {
|
||||
vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteByte('[')
|
||||
b.WriteString(strings.Join(vals, ", "))
|
||||
b.WriteByte(']')
|
||||
|
||||
return strings.Join([]string{query, b.String()}, " ")
|
||||
}
|
||||
|
||||
func format(query string, args ...interface{}) (string, error) {
|
||||
numArgs := len(args)
|
||||
if numArgs == 0 {
|
||||
|
||||
@@ -28,3 +28,31 @@ func TestDesensitize_WithoutAccount(t *testing.T) {
|
||||
datasource = desensitize(datasource)
|
||||
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
|
||||
}
|
||||
|
||||
func TestFormatForPrint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
args []interface{}
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "no args",
|
||||
query: "select user, name from table where id=?",
|
||||
expect: `select user, name from table where id=?`,
|
||||
},
|
||||
{
|
||||
name: "one arg",
|
||||
query: "select user, name from table where id=?",
|
||||
args: []interface{}{"kevin"},
|
||||
expect: `select user, name from table where id=? ["kevin"]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
actual := formatForPrint(test.query, test.args...)
|
||||
assert.Equal(t, test.expect, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type (
|
||||
|
||||
// NewReplacer returns a Replacer.
|
||||
func NewReplacer(mapping map[string]string) Replacer {
|
||||
var rep = &replacer{
|
||||
rep := &replacer{
|
||||
mapping: mapping,
|
||||
}
|
||||
for k := range mapping {
|
||||
@@ -28,9 +28,9 @@ func NewReplacer(mapping map[string]string) Replacer {
|
||||
|
||||
func (r *replacer) Replace(text string) string {
|
||||
var builder strings.Builder
|
||||
var chars = []rune(text)
|
||||
var size = len(chars)
|
||||
var start = -1
|
||||
chars := []rune(text)
|
||||
size := len(chars)
|
||||
start := -1
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
child, ok := r.children[chars[i]]
|
||||
@@ -42,12 +42,12 @@ func (r *replacer) Replace(text string) string {
|
||||
if start < 0 {
|
||||
start = i
|
||||
}
|
||||
var end = -1
|
||||
end := -1
|
||||
if child.end {
|
||||
end = i + 1
|
||||
}
|
||||
|
||||
var j = i + 1
|
||||
j := i + 1
|
||||
for ; j < size; j++ {
|
||||
grandchild, ok := child.children[chars[j]]
|
||||
if !ok {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func TestReplacer_Replace(t *testing.T) {
|
||||
var mapping = map[string]string{
|
||||
mapping := map[string]string{
|
||||
"一二三四": "1234",
|
||||
"二三": "23",
|
||||
"二": "2",
|
||||
@@ -16,28 +16,28 @@ func TestReplacer_Replace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceSingleChar(t *testing.T) {
|
||||
var mapping = map[string]string{
|
||||
mapping := map[string]string{
|
||||
"二": "2",
|
||||
}
|
||||
assert.Equal(t, "零一2三四五", NewReplacer(mapping).Replace("零一二三四五"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceExceedRange(t *testing.T) {
|
||||
var mapping = map[string]string{
|
||||
mapping := map[string]string{
|
||||
"二三四五六": "23456",
|
||||
}
|
||||
assert.Equal(t, "零一二三四五", NewReplacer(mapping).Replace("零一二三四五"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplacePartialMatch(t *testing.T) {
|
||||
var mapping = map[string]string{
|
||||
mapping := map[string]string{
|
||||
"二三四七": "2347",
|
||||
}
|
||||
assert.Equal(t, "零一二三四五", NewReplacer(mapping).Replace("零一二三四五"))
|
||||
}
|
||||
|
||||
func TestReplacer_ReplaceMultiMatches(t *testing.T) {
|
||||
var mapping = map[string]string{
|
||||
mapping := map[string]string{
|
||||
"二三": "23",
|
||||
}
|
||||
assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五"))
|
||||
|
||||
@@ -9,7 +9,12 @@ type Barrier struct {
|
||||
|
||||
// Guard guards the given fn on the resource.
|
||||
func (b *Barrier) Guard(fn func()) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
Guard(&b.lock, fn)
|
||||
}
|
||||
|
||||
// Guard guards the given fn with lock.
|
||||
func Guard(lock sync.Locker, fn func()) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
fn()
|
||||
}
|
||||
|
||||
@@ -38,3 +38,19 @@ func TestBarrierPtr_Guard(t *testing.T) {
|
||||
wg.Wait()
|
||||
assert.Equal(t, total, count)
|
||||
}
|
||||
|
||||
func TestGuard(t *testing.T) {
|
||||
const total = 10000
|
||||
var count int
|
||||
var lock sync.Mutex
|
||||
wg := new(sync.WaitGroup)
|
||||
wg.Add(total)
|
||||
for i := 0; i < total; i++ {
|
||||
go Guard(&lock, func() {
|
||||
count++
|
||||
wg.Done()
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
assert.Equal(t, total, count)
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -6,7 +6,6 @@ require (
|
||||
github.com/ClickHouse/clickhouse-go v1.4.3
|
||||
github.com/DATA-DOG/go-sqlmock v1.4.1
|
||||
github.com/alicebob/miniredis/v2 v2.14.1
|
||||
github.com/antlr/antlr4 v0.0.0-20210105212045-464bcbc32de2
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
||||
github.com/emicklei/proto v1.9.0
|
||||
@@ -44,6 +43,7 @@ require (
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 // indirect
|
||||
github.com/urfave/cli v1.22.5
|
||||
github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
|
||||
github.com/zeromicro/antlr v0.0.1 // indirect
|
||||
go.etcd.io/etcd v0.0.0-20200402134248-51bdeb39e698
|
||||
go.uber.org/automaxprocs v1.3.0
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -276,6 +276,10 @@ github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2/go.mod h1:hzfGeI
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb h1:ZkM6LRnq40pR1Ox0hTHlnpkcOTuFIDQpZ1IN8rKKhX0=
|
||||
github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb/go.mod h1:gqRgreBUhTSL0GeU64rtZ3Uq3wtjOa/TB2YfrtkCbVQ=
|
||||
github.com/zeromicro/antlr v0.0.0-20210508120604-8d7a7786f5c4 h1:wt86gT5lsN+xWmij6lFCrDIqoxeOnqM2MxiO7cNm+Lo=
|
||||
github.com/zeromicro/antlr v0.0.0-20210508120604-8d7a7786f5c4/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
|
||||
github.com/zeromicro/antlr v0.0.1 h1:CQpIn/dc0pUjgGQ81y98s/NGOm2Hfru2NNio2I9mQgk=
|
||||
github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
|
||||
go.etcd.io/bbolt v1.3.4 h1:hi1bXHMVrlQh6WwxAy+qZCV/SYIlqo+Ushwdpa4tAKg=
|
||||
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
|
||||
go.etcd.io/etcd v0.0.0-20200402134248-51bdeb39e698 h1:jWtjCJX1qxhHISBMLRztWwR+EXkI7MJAF2HjHAE/x/I=
|
||||
|
||||
54
readme-cn.md
54
readme-cn.md
@@ -159,7 +159,17 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
|
||||
|
||||
* API 文档
|
||||
|
||||
[https://www.yuque.com/tal-tech/go-zero](https://www.yuque.com/tal-tech/go-zero)
|
||||
[https://go-zero.dev/zh-hans/](https://zeromicro.github.io/go-zero)
|
||||
|
||||
* 常见问题
|
||||
|
||||
* 因为 `etcd` 和 `grpc` 兼容性问题,请使用 `grpc@v1.29.1`
|
||||
|
||||
`google.golang.org/grpc v1.29.1`
|
||||
|
||||
* 因为 `protobuf` 兼容性问题,请使用 `protocol-gen@v1.3.2`
|
||||
|
||||
`go get -u github.com/golang/protobuf/protoc-gen-go@v1.3.2`
|
||||
|
||||
* awesome 系列(更多文章见『微服务实践』公众号)
|
||||
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
|
||||
@@ -175,13 +185,39 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
|
||||
| [goctl-android](https://github.com/zeromicro/goctl-android) | 生成 `java (android)` 端 `http client` 请求代码 |
|
||||
| [goctl-go-compact](https://github.com/zeromicro/goctl-go-compact) | 合并 `api` 里同一个 `group` 里的 `handler` 到一个 `go` 文件 |
|
||||
|
||||
## 8. 微信公众号
|
||||
## 8. go-zero 用户
|
||||
|
||||
`go-zero` 相关文章都会在 `微服务实践` 公众号整理呈现,欢迎扫码关注,也可以通过公众号私信我 👏
|
||||
go-zero 已被许多公司用于生产部署,接入场景如在线教育、电商业务、游戏、区块链等,目前为止,已使用 go-zero 的公司包括但不限于:
|
||||
|
||||
>1. 好未来
|
||||
>2. 上海晓信信息科技有限公司(晓黑板)
|
||||
>3. 上海玉数科技有限公司
|
||||
>4. 常州千帆网络科技有限公司
|
||||
>5. 上班族科技
|
||||
>6. 英雄体育(VSPN)
|
||||
>7. githubmemory
|
||||
>8. 释空(上海)品牌策划有限公司(senkoo)
|
||||
>9. 鞍山三合众鑫科技有限公司
|
||||
>10. 广州星梦工场网络科技有限公司
|
||||
>11. 杭州复杂美科技有限公司
|
||||
>12. 赛凌科技
|
||||
>13. 捞月狗
|
||||
>14. 浙江三合通信科技有限公司
|
||||
>15. 爱克萨
|
||||
>16. 郑州众合互联信息技术有限公司
|
||||
>17. 三七游戏
|
||||
>18. 成都创道夫科技有限公司
|
||||
>19. 联想Lenovo
|
||||
|
||||
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/tal-tech/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
||||
|
||||
## 9. 微信公众号
|
||||
|
||||
`go-zero` 相关文章和视频都会在 `微服务实践` 公众号整理呈现,欢迎扫码关注 👏
|
||||
|
||||
<img src="https://gitee.com/kevwan/static/raw/master/images/wechat-micro.jpg" alt="wechat" width="300" />
|
||||
|
||||
## 9. 微信交流群
|
||||
## 10. 微信交流群
|
||||
|
||||
如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。
|
||||
|
||||
@@ -189,12 +225,6 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
|
||||
|
||||
如果您发现 ***bug*** 请及时提 ***issue***,我们会尽快确认并修改。
|
||||
|
||||
为了防止广告用户、识别技术同行,请 ***star*** 后加我时注明 **github** 当前 ***star*** 数,我再拉进 **go-zero** 群,感谢!
|
||||
加群之前有劳点一下 ***star***,一个小小的 ***star*** 是作者们回答海量问题的动力🤝
|
||||
|
||||
加我之前有劳点一下 ***star***,一个小小的 ***star*** 是作者们回答海量问题的动力🤝
|
||||
|
||||
<img src="https://gitee.com/kevwan/static/raw/master/images/wechat.jpg" alt="wechat" width="300" />
|
||||
|
||||
项目地址:[https://github.com/tal-tech/go-zero](https://github.com/tal-tech/go-zero)
|
||||
|
||||
码云地址:[https://gitee.com/kevwan/go-zero](https://gitee.com/kevwan/go-zero) (国内用户可访问gitee,每日自动从github同步代码)
|
||||
<img src="https://raw.githubusercontent.com/tal-tech/zero-doc/main/doc/images/wechat.jpg" alt="wechat" width="300" />
|
||||
24
readme.md
24
readme.md
@@ -12,7 +12,7 @@ English | [简体中文](readme-cn.md)
|
||||
|
||||
## 0. what is go-zero
|
||||
|
||||
go-zero is a web and rpc framework that with lots of engineering practices builtin. It’s born to ensure the stability of the busy services with resilience design, and has been serving sites with tens of millions users for years.
|
||||
go-zero is a web and rpc framework with lots of builtin engineering practices. It’s born to ensure the stability of the busy services with resilience design, and has been serving sites with tens of millions users for years.
|
||||
|
||||
go-zero contains simple API description syntax and code generation tool called `goctl`. You can generate Go, iOS, Android, Kotlin, Dart, TypeScript, JavaScript from .api files with `goctl`.
|
||||
|
||||
@@ -115,11 +115,11 @@ go get -u github.com/tal-tech/go-zero
|
||||
type Request struct {
|
||||
Name string `path:"name,options=you|me"` // parameters are auto validated
|
||||
}
|
||||
|
||||
|
||||
type Response struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
|
||||
service greet-api {
|
||||
@handler GreetHandler
|
||||
get /greet/from/:name(Request) returns (Response);
|
||||
@@ -200,6 +200,8 @@ go get -u github.com/tal-tech/go-zero
|
||||
|
||||
## 7. Benchmark
|
||||
|
||||
Document: [https://go-zero.dev/en/](https://go-zero.dev/en/)
|
||||
|
||||

|
||||
|
||||
[Checkout the test code](https://github.com/smallnest/go-web-framework-benchmark)
|
||||
@@ -210,6 +212,20 @@ go get -u github.com/tal-tech/go-zero
|
||||
* [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore-en.md)
|
||||
* [Examples](https://github.com/zeromicro/zero-examples)
|
||||
|
||||
## 9. Chat group
|
||||
## 9. Important notes
|
||||
|
||||
* Use grpc 1.29.1, because etcd lib doesn’t support latter versions.
|
||||
|
||||
`google.golang.org/grpc v1.29.1`
|
||||
|
||||
* For protobuf compatibility, use `protocol-gen@v1.3.2`.
|
||||
|
||||
` go get -u github.com/golang/protobuf/protoc-gen-go@v1.3.2`
|
||||
|
||||
## 10. Chat group
|
||||
|
||||
Join the chat via https://join.slack.com/t/go-zeroworkspace/shared_invite/zt-m39xssxc-kgIqERa7aVsujKNj~XuPKg
|
||||
|
||||
## Give a Star! ⭐
|
||||
|
||||
If you like or are using this project to learn or start your solution, please give it a star. Thanks!
|
||||
@@ -109,13 +109,13 @@ func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat
|
||||
chain := alice.New(
|
||||
handler.TracingHandler,
|
||||
s.getLogHandler(),
|
||||
handler.PrometheusHandler(route.Path),
|
||||
handler.MaxConns(s.conf.MaxConns),
|
||||
handler.BreakerHandler(route.Method, route.Path, metrics),
|
||||
handler.SheddingHandler(s.getShedder(fr.priority), metrics),
|
||||
handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
|
||||
handler.RecoverHandler,
|
||||
handler.MetricHandler(metrics),
|
||||
handler.PrometheusHandler(route.Path),
|
||||
handler.MaxBytesHandler(s.conf.MaxBytes),
|
||||
handler.GunzipHandler,
|
||||
)
|
||||
|
||||
@@ -154,8 +154,7 @@ Verbose: true
|
||||
}
|
||||
}
|
||||
|
||||
type mockedRouter struct {
|
||||
}
|
||||
type mockedRouter struct{}
|
||||
|
||||
func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
||||
@@ -138,6 +140,16 @@ func (grw *guardedResponseWriter) Header() http.Header {
|
||||
return grw.writer.Header()
|
||||
}
|
||||
|
||||
// Hijack implements the http.Hijacker interface.
|
||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||
func (grw *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacked, ok := grw.writer.(http.Hijacker); ok {
|
||||
return hijacked.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("server doesn't support hijacking")
|
||||
}
|
||||
|
||||
func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
|
||||
return grw.writer.Write(body)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -87,6 +89,26 @@ func TestAuthHandler_NilError(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_Flush(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
handler := newGuardedResponseWriter(resp)
|
||||
handler.Flush()
|
||||
assert.True(t, resp.Flushed)
|
||||
}
|
||||
|
||||
func TestAuthHandler_Hijack(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
writer := newGuardedResponseWriter(resp)
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
|
||||
writer = newGuardedResponseWriter(mockedHijackable{resp})
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
}
|
||||
|
||||
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
|
||||
now := time.Now().Unix()
|
||||
claims := make(jwt.MapClaims)
|
||||
@@ -101,3 +123,11 @@ func buildToken(secretKey string, payloads map[string]interface{}, seconds int64
|
||||
|
||||
return token.SignedString([]byte(secretKey))
|
||||
}
|
||||
|
||||
type mockedHijackable struct {
|
||||
*httptest.ResponseRecorder
|
||||
}
|
||||
|
||||
func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/codec"
|
||||
@@ -94,6 +96,16 @@ func (w *cryptionResponseWriter) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
// Hijack implements the http.Hijacker interface.
|
||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||
func (w *cryptionResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacked, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hijacked.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("server doesn't support hijacking")
|
||||
}
|
||||
|
||||
func (w *cryptionResponseWriter) Write(p []byte) (int, error) {
|
||||
return w.buf.Write(p)
|
||||
}
|
||||
|
||||
@@ -103,3 +103,16 @@ func TestCryptionHandlerFlush(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
|
||||
}
|
||||
|
||||
func TestCryptionHandler_Hijack(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
writer := newCryptionResponseWriter(resp)
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
|
||||
writer = newCryptionResponseWriter(mockedHijackable{resp})
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
@@ -25,10 +28,26 @@ type loggedResponseWriter struct {
|
||||
code int
|
||||
}
|
||||
|
||||
func (w *loggedResponseWriter) Flush() {
|
||||
if flusher, ok := w.w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *loggedResponseWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
// Hijack implements the http.Hijacker interface.
|
||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||
func (w *loggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacked, ok := w.w.(http.Hijacker); ok {
|
||||
return hijacked.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("server doesn't support hijacking")
|
||||
}
|
||||
|
||||
func (w *loggedResponseWriter) Write(bytes []byte) (int, error) {
|
||||
return w.w.Write(bytes)
|
||||
}
|
||||
@@ -38,12 +57,6 @@ func (w *loggedResponseWriter) WriteHeader(code int) {
|
||||
w.code = code
|
||||
}
|
||||
|
||||
func (w *loggedResponseWriter) Flush() {
|
||||
if flusher, ok := w.w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// LogHandler returns a middleware that logs http request and response.
|
||||
func LogHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -83,6 +96,16 @@ func (w *detailLoggedResponseWriter) Header() http.Header {
|
||||
return w.writer.Header()
|
||||
}
|
||||
|
||||
// Hijack implements the http.Hijacker interface.
|
||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||
func (w *detailLoggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacked, ok := w.writer.w.(http.Hijacker); ok {
|
||||
return hijacked.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("server doesn't support hijacking")
|
||||
}
|
||||
|
||||
func (w *detailLoggedResponseWriter) Write(bs []byte) (int, error) {
|
||||
w.buf.Write(bs)
|
||||
return w.writer.Write(bs)
|
||||
|
||||
@@ -62,6 +62,44 @@ func TestLogHandlerSlow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogHandler_Hijack(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
writer := &loggedResponseWriter{
|
||||
w: resp,
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
|
||||
writer = &loggedResponseWriter{
|
||||
w: mockedHijackable{resp},
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
}
|
||||
|
||||
func TestDetailedLogHandler_Hijack(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
writer := &detailLoggedResponseWriter{
|
||||
writer: &loggedResponseWriter{
|
||||
w: resp,
|
||||
},
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
|
||||
writer = &detailLoggedResponseWriter{
|
||||
writer: &loggedResponseWriter{
|
||||
w: mockedHijackable{resp},
|
||||
},
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkLogHandler(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/metric"
|
||||
"github.com/tal-tech/go-zero/core/prometheus"
|
||||
"github.com/tal-tech/go-zero/core/timex"
|
||||
"github.com/tal-tech/go-zero/rest/internal/security"
|
||||
)
|
||||
@@ -34,6 +35,10 @@ var (
|
||||
// PrometheusHandler returns a middleware that reports stats to prometheus.
|
||||
func PrometheusHandler(path string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
if !prometheus.Enabled() {
|
||||
return next
|
||||
}
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := timex.Now()
|
||||
cw := &security.WithCodeResponseWriter{Writer: w}
|
||||
|
||||
@@ -6,9 +6,26 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/prometheus"
|
||||
)
|
||||
|
||||
func TestPromMetricHandler(t *testing.T) {
|
||||
func TestPromMetricHandler_Disabled(t *testing.T) {
|
||||
promMetricHandler := PrometheusHandler("/user/login")
|
||||
handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
handler.ServeHTTP(resp, req)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
func TestPromMetricHandler_Enabled(t *testing.T) {
|
||||
prometheus.StartAgent(prometheus.Config{
|
||||
Host: "localhost",
|
||||
Path: "/",
|
||||
})
|
||||
promMetricHandler := PrometheusHandler("/user/login")
|
||||
handler := promMetricHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
@@ -94,8 +94,7 @@ func (s mockShedder) Allow() (load.Promise, error) {
|
||||
return nil, load.ErrServiceOverloaded
|
||||
}
|
||||
|
||||
type mockPromise struct {
|
||||
}
|
||||
type mockPromise struct{}
|
||||
|
||||
func (p mockPromise) Pass() {
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package security
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// A WithCodeResponseWriter is a helper to delay sealing a http.ResponseWriter on writing code.
|
||||
type WithCodeResponseWriter struct {
|
||||
@@ -20,6 +24,12 @@ func (w *WithCodeResponseWriter) Header() http.Header {
|
||||
return w.Writer.Header()
|
||||
}
|
||||
|
||||
// Hijack implements the http.Hijacker interface.
|
||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||
func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return w.Writer.(http.Hijacker).Hijack()
|
||||
}
|
||||
|
||||
// Write writes bytes into w.
|
||||
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
|
||||
return w.Writer.Write(bytes)
|
||||
|
||||
@@ -77,6 +77,8 @@ func (e *Server) AddRoute(r Route, opts ...RouteOption) {
|
||||
}
|
||||
|
||||
// Start starts the Server.
|
||||
// Graceful shutdown is enabled by default.
|
||||
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
|
||||
func (e *Server) Start() {
|
||||
handleError(e.opts.start(e.ngin))
|
||||
}
|
||||
@@ -108,7 +110,7 @@ func WithJwt(secret string) RouteOption {
|
||||
}
|
||||
|
||||
// WithJwtTransition returns a func to enable jwt authentication as well as jwt secret transition.
|
||||
// Which means old and new jwt secrets work together for a peroid.
|
||||
// Which means old and new jwt secrets work together for a period.
|
||||
func WithJwtTransition(secret, prevSecret string) RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
// why not validate prevSecret, because prevSecret is an already used one,
|
||||
|
||||
@@ -29,7 +29,7 @@ Future {{pathToFuncName .Path}}( {{if ne .Method "get"}}{{with .RequestType}}{{.
|
||||
{{end}}`
|
||||
|
||||
func genApi(dir string, api *spec.ApiSpec) error {
|
||||
err := os.MkdirAll(dir, 0755)
|
||||
err := os.MkdirAll(dir, 0o755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func genApi(dir string, api *spec.ApiSpec) error {
|
||||
return err
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(dir+api.Service.Name+".dart", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
file, err := os.OpenFile(dir+api.Service.Name+".dart", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -60,7 +60,7 @@ func genApiFile(dir string) error {
|
||||
if fileExists(path) {
|
||||
return nil
|
||||
}
|
||||
apiFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
apiFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ class {{.Name}}{
|
||||
`
|
||||
|
||||
func genData(dir string, api *spec.ApiSpec) error {
|
||||
err := os.MkdirAll(dir, 0755)
|
||||
err := os.MkdirAll(dir, 0o755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -42,7 +42,7 @@ func genData(dir string, api *spec.ApiSpec) error {
|
||||
return err
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(dir+api.Service.Name+".dart", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
file, err := os.OpenFile(dir+api.Service.Name+".dart", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func genTokens(dir string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
tokensFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
tokensFile, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -41,20 +41,20 @@ Future<Tokens> getTokens() async {
|
||||
`
|
||||
|
||||
func genVars(dir string) error {
|
||||
err := os.MkdirAll(dir, 0755)
|
||||
err := os.MkdirAll(dir, 0o755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !fileExists(dir + "vars.dart") {
|
||||
err = ioutil.WriteFile(dir+"vars.dart", []byte(`const serverHost='demo-crm.xiaoheiban.cn';`), 0644)
|
||||
err = ioutil.WriteFile(dir+"vars.dart", []byte(`const serverHost='demo-crm.xiaoheiban.cn';`), 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if !fileExists(dir + "kv.dart") {
|
||||
err = ioutil.WriteFile(dir+"kv.dart", []byte(varTemplate), 0644)
|
||||
err = ioutil.WriteFile(dir+"kv.dart", []byte(varTemplate), 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func buildDoc(route spec.Type) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var tps = make([]spec.Type, 0)
|
||||
tps := make([]spec.Type, 0)
|
||||
tps = append(tps, route)
|
||||
if definedType, ok := route.(spec.DefineStruct); ok {
|
||||
associatedTypes(definedType, &tps)
|
||||
@@ -98,7 +98,7 @@ func buildDoc(route spec.Type) (string, error) {
|
||||
}
|
||||
|
||||
func associatedTypes(tp spec.DefineStruct, tps *[]spec.Type) {
|
||||
var hasAdded = false
|
||||
hasAdded := false
|
||||
for _, item := range *tps {
|
||||
if item.Name() == tp.Name() {
|
||||
hasAdded = true
|
||||
|
||||
@@ -107,8 +107,8 @@ func apiFormat(data string) (string, error) {
|
||||
|
||||
var builder strings.Builder
|
||||
s := bufio.NewScanner(strings.NewReader(data))
|
||||
var tapCount = 0
|
||||
var newLineCount = 0
|
||||
tapCount := 0
|
||||
newLineCount := 0
|
||||
var preLine string
|
||||
for s.Scan() {
|
||||
line := strings.TrimSpace(s.Text())
|
||||
|
||||
@@ -35,12 +35,12 @@ func genConfig(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var authNames = getAuths(api)
|
||||
authNames := getAuths(api)
|
||||
var auths []string
|
||||
for _, item := range authNames {
|
||||
auths = append(auths, fmt.Sprintf("%s %s", item, jwtTemplate))
|
||||
}
|
||||
var authImportStr = fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL)
|
||||
authImportStr := fmt.Sprintf("\"%s/rest\"", vars.ProjectOpenSourceURL)
|
||||
|
||||
return genFile(fileGenConfig{
|
||||
dir: dir,
|
||||
|
||||
@@ -31,7 +31,7 @@ func (m *{{.name}})Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||
`
|
||||
|
||||
func genMiddleware(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||
var middlewares = getMiddleware(api)
|
||||
middlewares := getMiddleware(api)
|
||||
for _, item := range middlewares {
|
||||
middlewareFilename := strings.TrimSuffix(strings.ToLower(item), "middleware") + "_middleware"
|
||||
filename, err := format.FileNamingFormat(cfg.NamingFormat, middlewareFilename)
|
||||
|
||||
@@ -95,11 +95,11 @@ func genRoutes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||
var routes string
|
||||
if len(g.middlewares) > 0 {
|
||||
gbuilder.WriteString("\n}...,")
|
||||
var params = g.middlewares
|
||||
params := g.middlewares
|
||||
for i := range params {
|
||||
params[i] = "serverCtx." + params[i]
|
||||
}
|
||||
var middlewareStr = strings.Join(params, ", ")
|
||||
middlewareStr := strings.Join(params, ", ")
|
||||
routes = fmt.Sprintf("rest.WithMiddlewares(\n[]rest.Middleware{ %s }, \n %s \n),",
|
||||
middlewareStr, strings.TrimSpace(gbuilder.String()))
|
||||
} else {
|
||||
@@ -146,7 +146,7 @@ func genRoutes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||
}
|
||||
|
||||
func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
|
||||
var importSet = collection.NewSet()
|
||||
importSet := collection.NewSet()
|
||||
importSet.AddStr(fmt.Sprintf("\"%s\"", util.JoinPackages(parentPkg, contextDir)))
|
||||
for _, group := range api.Service.Groups {
|
||||
for _, route := range group.Routes {
|
||||
|
||||
@@ -39,7 +39,7 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
|
||||
return err
|
||||
}
|
||||
|
||||
var authNames = getAuths(api)
|
||||
authNames := getAuths(api)
|
||||
var auths []string
|
||||
for _, item := range authNames {
|
||||
auths = append(auths, fmt.Sprintf("%s config.AuthConfig", item))
|
||||
@@ -52,7 +52,7 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
|
||||
|
||||
var middlewareStr string
|
||||
var middlewareAssignment string
|
||||
var middlewares = getMiddleware(api)
|
||||
middlewares := getMiddleware(api)
|
||||
|
||||
for _, item := range middlewares {
|
||||
middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
|
||||
@@ -61,7 +61,7 @@ func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error
|
||||
fmt.Sprintf("middleware.New%s().%s", strings.Title(name), "Handle"))
|
||||
}
|
||||
|
||||
var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
|
||||
configImport := "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
|
||||
if len(middlewareStr) > 0 {
|
||||
configImport += "\n\t\"" + ctlutil.JoinPackages(parentPkg, middlewareDir) + "\""
|
||||
configImport += fmt.Sprintf("\n\t\"%s/rest\"", vars.ProjectOpenSourceURL)
|
||||
|
||||
@@ -103,17 +103,9 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
|
||||
}
|
||||
|
||||
func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error {
|
||||
defineStruct, ok := ty.(spec.DefineStruct)
|
||||
if !ok {
|
||||
return errors.New("unsupported type %s" + ty.Name())
|
||||
}
|
||||
|
||||
for _, item := range c.requestTypes {
|
||||
if item.Name() == defineStruct.Name() {
|
||||
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
defineStruct, done, err := c.checkStruct(ty)
|
||||
if done {
|
||||
return err
|
||||
}
|
||||
|
||||
modelFile := util.Title(ty.Name()) + ".java"
|
||||
@@ -181,6 +173,22 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *componentsContext) checkStruct(ty spec.Type) (spec.DefineStruct, bool, error) {
|
||||
defineStruct, ok := ty.(spec.DefineStruct)
|
||||
if !ok {
|
||||
return spec.DefineStruct{}, true, errors.New("unsupported type %s" + ty.Name())
|
||||
}
|
||||
|
||||
for _, item := range c.requestTypes {
|
||||
if item.Name() == defineStruct.Name() {
|
||||
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
|
||||
return spec.DefineStruct{}, true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return defineStruct, false, nil
|
||||
}
|
||||
|
||||
func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) {
|
||||
var builder strings.Builder
|
||||
if err := c.writeType(&builder, defineStruct); err != nil {
|
||||
@@ -269,15 +277,15 @@ func (c *componentsContext) buildConstructor() (string, string, error) {
|
||||
}
|
||||
|
||||
func (c *componentsContext) genGetSet(writer io.Writer, indent int) error {
|
||||
var members = c.members
|
||||
members := c.members
|
||||
for _, member := range members {
|
||||
javaType, err := specTypeToJava(member.Type)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var property = util.Title(member.Name)
|
||||
var templateStr = getSetTemplate
|
||||
property := util.Title(member.Name)
|
||||
templateStr := getSetTemplate
|
||||
if javaType == "boolean" {
|
||||
templateStr = boolTemplate
|
||||
property = strings.TrimPrefix(property, "Is")
|
||||
|
||||
@@ -67,7 +67,7 @@ func createWith(dir string, api *spec.ApiSpec, route spec.Route, packetName stri
|
||||
}
|
||||
defer fp.Close()
|
||||
|
||||
var hasRequestBody = false
|
||||
hasRequestBody := false
|
||||
if route.RequestType != nil {
|
||||
if defineStruct, ok := route.RequestType.(spec.DefineStruct); ok {
|
||||
hasRequestBody = len(defineStruct.GetBodyMembers()) > 0 || len(defineStruct.GetFormMembers()) > 0
|
||||
|
||||
@@ -61,7 +61,7 @@ func writeIndent(writer io.Writer, indent int) {
|
||||
}
|
||||
|
||||
func indentString(indent int) string {
|
||||
var result = ""
|
||||
result := ""
|
||||
for i := 0; i < indent; i++ {
|
||||
result += "\t"
|
||||
}
|
||||
@@ -95,17 +95,9 @@ func specTypeToJava(tp spec.Type) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch valueType {
|
||||
case "int":
|
||||
return "Integer[]", nil
|
||||
case "long":
|
||||
return "Long[]", nil
|
||||
case "float":
|
||||
return "Float[]", nil
|
||||
case "double":
|
||||
return "Double[]", nil
|
||||
case "boolean":
|
||||
return "Boolean[]", nil
|
||||
s := getBaseType(valueType)
|
||||
if len(s) == 0 {
|
||||
return s, errors.New("unsupported primitive type " + tp.Name())
|
||||
}
|
||||
|
||||
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil
|
||||
@@ -118,6 +110,23 @@ func specTypeToJava(tp spec.Type) (string, error) {
|
||||
return "", errors.New("unsupported primitive type " + tp.Name())
|
||||
}
|
||||
|
||||
func getBaseType(valueType string) string {
|
||||
switch valueType {
|
||||
case "int":
|
||||
return "Integer[]"
|
||||
case "long":
|
||||
return "Long[]"
|
||||
case "float":
|
||||
return "Float[]"
|
||||
case "double":
|
||||
return "Double[]"
|
||||
case "boolean":
|
||||
return "Boolean[]"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func primitiveType(tp string) (string, bool) {
|
||||
switch tp {
|
||||
case "string":
|
||||
|
||||
@@ -98,7 +98,7 @@ object {{with .Info}}{{.Title}}{{end}}{
|
||||
)
|
||||
|
||||
func genBase(dir, pkg string, api *spec.ApiSpec) error {
|
||||
e := os.MkdirAll(dir, 0755)
|
||||
e := os.MkdirAll(dir, 0o755)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
@@ -108,7 +108,7 @@ func genBase(dir, pkg string, api *spec.ApiSpec) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
file, e := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
file, e := os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
@@ -146,12 +146,12 @@ func genApi(dir, pkg string, api *spec.ApiSpec) error {
|
||||
api.Info.Title = name
|
||||
api.Info.Desc = desc
|
||||
|
||||
e := os.MkdirAll(dir, 0755)
|
||||
e := os.MkdirAll(dir, 0o755)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
file, e := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
|
||||
file, e := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0o644)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
|
||||
for _, each := range ctx.AllSpec() {
|
||||
root := each.Accept(v).(*Api)
|
||||
v.acceptSyntax(root, &final)
|
||||
v.accetpImport(root, &final)
|
||||
v.acceptImport(root, &final)
|
||||
v.acceptInfo(root, &final)
|
||||
v.acceptType(root, &final)
|
||||
v.acceptService(root, &final)
|
||||
@@ -133,7 +133,7 @@ func (v *ApiVisitor) acceptInfo(root *Api, final *Api) {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *ApiVisitor) accetpImport(root *Api, final *Api) {
|
||||
func (v *ApiVisitor) acceptImport(root *Api, final *Api) {
|
||||
for _, imp := range root.Import {
|
||||
if _, ok := final.importM[imp.Value.Text()]; ok {
|
||||
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/antlr/antlr4/runtime/Go/antlr"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/console"
|
||||
"github.com/zeromicro/antlr"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -175,7 +175,7 @@ func (p *Parser) valid(mainApi *Api, nestedApi *Api) error {
|
||||
for _, g := range list {
|
||||
handler := g.GetHandler()
|
||||
if handler.IsNotNil() {
|
||||
var handlerName = handler.Text()
|
||||
handlerName := handler.Text()
|
||||
handlerMap[handlerName] = Holder
|
||||
path := fmt.Sprintf("%s://%s", g.Route.Method.Text(), g.Route.Path.Text())
|
||||
routeMap[path] = Holder
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/antlr/antlr4/runtime/Go/antlr"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/console"
|
||||
"github.com/zeromicro/antlr"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -289,7 +289,7 @@ func (v *ApiVisitor) getHiddenTokensToLeft(t TokenStream, channel int, containsC
|
||||
|
||||
if index > 0 {
|
||||
allTokens := ct.GetAllTokens()
|
||||
var flag = false
|
||||
flag := false
|
||||
for i := index; i >= 0; i-- {
|
||||
tk := allTokens[i]
|
||||
if tk.GetChannel() == antlr.LexerDefaultTokenChannel {
|
||||
|
||||
@@ -50,7 +50,7 @@ type AtDoc struct {
|
||||
Kv []*KvExpr
|
||||
}
|
||||
|
||||
// AtHandler describes service hander ast for api syntax
|
||||
// AtHandler describes service handler ast for api syntax
|
||||
type AtHandler struct {
|
||||
AtHandlerToken Expr
|
||||
Name Expr
|
||||
@@ -630,7 +630,7 @@ func (s *Service) Equal(v interface{}) bool {
|
||||
return s.ServiceApi.Equal(service.ServiceApi)
|
||||
}
|
||||
|
||||
// Get returns the tergate KV by specified key
|
||||
// Get returns the target KV by specified key
|
||||
func (kv KV) Get(key string) Expr {
|
||||
for _, each := range kv {
|
||||
if each.Key.Text() == key {
|
||||
|
||||
@@ -17,7 +17,7 @@ type (
|
||||
NameExpr() Expr
|
||||
}
|
||||
|
||||
// TypeAlias describes alias ast for api syatax
|
||||
// TypeAlias describes alias ast for api syntax
|
||||
TypeAlias struct {
|
||||
Name Expr
|
||||
Assign Expr
|
||||
@@ -26,7 +26,7 @@ type (
|
||||
CommentExpr Expr
|
||||
}
|
||||
|
||||
// TypeStruct describes structure ast for api syatax
|
||||
// TypeStruct describes structure ast for api syntax
|
||||
TypeStruct struct {
|
||||
Name Expr
|
||||
Struct Expr
|
||||
@@ -128,7 +128,6 @@ func (v *ApiVisitor) VisitTypeBlock(ctx *api.TypeBlockContext) interface{} {
|
||||
var types []TypeExpr
|
||||
for _, each := range list {
|
||||
types = append(types, each.Accept(v).(TypeExpr))
|
||||
|
||||
}
|
||||
return types
|
||||
}
|
||||
@@ -155,7 +154,6 @@ func (v *ApiVisitor) VisitTypeStruct(ctx *api.TypeStructContext) interface{} {
|
||||
st.Name = v.newExprWithToken(ctx.GetStructName())
|
||||
|
||||
if util.UnExport(ctx.GetStructName().GetText()) {
|
||||
|
||||
}
|
||||
if ctx.GetStructToken() != nil {
|
||||
structExpr := v.newExprWithToken(ctx.GetStructToken())
|
||||
@@ -225,7 +223,7 @@ func (v *ApiVisitor) VisitTypeBlockAlias(ctx *api.TypeBlockAliasContext) interfa
|
||||
alias.DocExpr = v.getDoc(ctx)
|
||||
alias.CommentExpr = v.getComment(ctx)
|
||||
// todo: reopen if necessary
|
||||
v.panic(alias.Name, "unsupport alias")
|
||||
v.panic(alias.Name, "unsupported alias")
|
||||
return &alias
|
||||
}
|
||||
|
||||
@@ -238,7 +236,7 @@ func (v *ApiVisitor) VisitTypeAlias(ctx *api.TypeAliasContext) interface{} {
|
||||
alias.DocExpr = v.getDoc(ctx)
|
||||
alias.CommentExpr = v.getComment(ctx)
|
||||
// todo: reopen if necessary
|
||||
v.panic(alias.Name, "unsupport alias")
|
||||
v.panic(alias.Name, "unsupported alias")
|
||||
return &alias
|
||||
}
|
||||
|
||||
@@ -319,7 +317,7 @@ func (v *ApiVisitor) VisitDataType(ctx *api.DataTypeContext) interface{} {
|
||||
if ctx.GetTime() != nil {
|
||||
// todo: reopen if it is necessary
|
||||
timeExpr := v.newExprWithToken(ctx.GetTime())
|
||||
v.panic(timeExpr, "unsupport time.Time")
|
||||
v.panic(timeExpr, "unsupported time.Time")
|
||||
return &Time{Literal: timeExpr}
|
||||
}
|
||||
if ctx.PointerType() != nil {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Code generated from tools/goctl/api/parser/g4/ApiParser.g4 by ANTLR 4.9. DO NOT EDIT.
|
||||
|
||||
package api // ApiParser
|
||||
import "github.com/antlr/antlr4/runtime/Go/antlr"
|
||||
import "github.com/zeromicro/antlr"
|
||||
|
||||
type BaseApiParserVisitor struct {
|
||||
*antlr.BaseParseTreeVisitor
|
||||
|
||||
@@ -6,12 +6,14 @@ import (
|
||||
"fmt"
|
||||
"unicode"
|
||||
|
||||
"github.com/antlr/antlr4/runtime/Go/antlr"
|
||||
"github.com/zeromicro/antlr"
|
||||
)
|
||||
|
||||
// Suppress unused import error
|
||||
var _ = fmt.Printf
|
||||
var _ = unicode.IsLetter
|
||||
var (
|
||||
_ = fmt.Printf
|
||||
_ = unicode.IsLetter
|
||||
)
|
||||
|
||||
var serializedLexerAtn = []uint16{
|
||||
3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 25, 266,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,7 @@
|
||||
// Code generated from tools/goctl/api/parser/g4/ApiParser.g4 by ANTLR 4.9. DO NOT EDIT.
|
||||
|
||||
package api // ApiParser
|
||||
import "github.com/antlr/antlr4/runtime/Go/antlr"
|
||||
import "github.com/zeromicro/antlr"
|
||||
|
||||
// A complete Visitor for a parse tree produced by ApiParserParser.
|
||||
type ApiParserVisitor interface {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/antlr/antlr4/runtime/Go/antlr"
|
||||
"github.com/zeromicro/antlr"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -16,28 +16,30 @@ const (
|
||||
tagRegex = `(?m)\x60[a-z]+:".+"\x60`
|
||||
)
|
||||
|
||||
var holder = struct{}{}
|
||||
var kind = map[string]struct{}{
|
||||
"bool": holder,
|
||||
"int": holder,
|
||||
"int8": holder,
|
||||
"int16": holder,
|
||||
"int32": holder,
|
||||
"int64": holder,
|
||||
"uint": holder,
|
||||
"uint8": holder,
|
||||
"uint16": holder,
|
||||
"uint32": holder,
|
||||
"uint64": holder,
|
||||
"uintptr": holder,
|
||||
"float32": holder,
|
||||
"float64": holder,
|
||||
"complex64": holder,
|
||||
"complex128": holder,
|
||||
"string": holder,
|
||||
"byte": holder,
|
||||
"rune": holder,
|
||||
}
|
||||
var (
|
||||
holder = struct{}{}
|
||||
kind = map[string]struct{}{
|
||||
"bool": holder,
|
||||
"int": holder,
|
||||
"int8": holder,
|
||||
"int16": holder,
|
||||
"int32": holder,
|
||||
"int64": holder,
|
||||
"uint": holder,
|
||||
"uint8": holder,
|
||||
"uint16": holder,
|
||||
"uint32": holder,
|
||||
"uint64": holder,
|
||||
"uintptr": holder,
|
||||
"float32": holder,
|
||||
"float64": holder,
|
||||
"complex64": holder,
|
||||
"complex128": holder,
|
||||
"string": holder,
|
||||
"byte": holder,
|
||||
"rune": holder,
|
||||
}
|
||||
)
|
||||
|
||||
func match(p *ApiParserParser, text string) {
|
||||
v := getCurrentTokenText(p)
|
||||
|
||||
@@ -145,7 +145,6 @@ line"`),
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
})
|
||||
|
||||
t.Run("mismatched", func(t *testing.T) {
|
||||
|
||||
@@ -120,7 +120,8 @@ func TestRoute(t *testing.T) {
|
||||
PointerExpr: ast.NewTextExpr("*Bar"),
|
||||
Star: ast.NewTextExpr("*"),
|
||||
Name: ast.NewTextExpr("Bar"),
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
@@ -224,7 +225,6 @@ func TestAtHandler(t *testing.T) {
|
||||
_, err = parser.Accept(fn, `@handler "foo"`)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestAtDoc(t *testing.T) {
|
||||
|
||||
@@ -64,7 +64,7 @@ func (p parser) convert2Spec() error {
|
||||
}
|
||||
|
||||
func (p parser) fillInfo() {
|
||||
var properties = make(map[string]string, 0)
|
||||
properties := make(map[string]string, 0)
|
||||
if p.ast.Info != nil {
|
||||
p.spec.Info = spec.Info{}
|
||||
for _, kv := range p.ast.Info.Kvs {
|
||||
@@ -147,8 +147,8 @@ func (p parser) findDefinedType(name string) (*spec.Type, error) {
|
||||
}
|
||||
|
||||
func (p parser) fieldToMember(field *ast.TypeField) spec.Member {
|
||||
var name = ""
|
||||
var tag = ""
|
||||
name := ""
|
||||
tag := ""
|
||||
if !field.IsAnonymous {
|
||||
name = field.Name.Text()
|
||||
if field.Tag == nil {
|
||||
@@ -219,9 +219,9 @@ func (p parser) fillService() error {
|
||||
|
||||
for _, astRoute := range item.ServiceApi.ServiceRoute {
|
||||
route := spec.Route{
|
||||
Annotation: spec.Annotation{},
|
||||
Method: astRoute.Route.Method.Text(),
|
||||
Path: astRoute.Route.Path.Text(),
|
||||
AtServerAnnotation: spec.Annotation{},
|
||||
Method: astRoute.Route.Method.Text(),
|
||||
Path: astRoute.Route.Path.Text(),
|
||||
}
|
||||
if astRoute.AtHandler != nil {
|
||||
route.Handler = astRoute.AtHandler.Name.Text()
|
||||
@@ -239,7 +239,7 @@ func (p parser) fillService() error {
|
||||
route.ResponseType = p.astTypeToSpec(astRoute.Route.Reply.Name)
|
||||
}
|
||||
if astRoute.AtDoc != nil {
|
||||
var properties = make(map[string]string, 0)
|
||||
properties := make(map[string]string, 0)
|
||||
for _, kv := range astRoute.AtDoc.Kv {
|
||||
properties[kv.Key.Text()] = kv.Value.Text()
|
||||
}
|
||||
@@ -271,11 +271,11 @@ func (p parser) fillService() error {
|
||||
|
||||
func (p parser) fillRouteAtServer(astRoute *ast.ServiceRoute, route *spec.Route) error {
|
||||
if astRoute.AtServer != nil {
|
||||
var properties = make(map[string]string, 0)
|
||||
properties := make(map[string]string, 0)
|
||||
for _, kv := range astRoute.AtServer.Kv {
|
||||
properties[kv.Key.Text()] = kv.Value.Text()
|
||||
}
|
||||
route.Annotation.Properties = properties
|
||||
route.AtServerAnnotation.Properties = properties
|
||||
if len(route.Handler) == 0 {
|
||||
route.Handler = properties["handler"]
|
||||
}
|
||||
@@ -295,7 +295,7 @@ func (p parser) fillRouteAtServer(astRoute *ast.ServiceRoute, route *spec.Route)
|
||||
|
||||
func (p parser) fillAtServer(item *ast.Service, group *spec.Group) {
|
||||
if item.AtServer != nil {
|
||||
var properties = make(map[string]string, 0)
|
||||
properties := make(map[string]string, 0)
|
||||
for _, kv := range item.AtServer.Kv {
|
||||
properties[kv.Key.Text()] = kv.Value.Text()
|
||||
}
|
||||
|
||||
@@ -11,10 +11,11 @@ import (
|
||||
const (
|
||||
bodyTagKey = "json"
|
||||
formTagKey = "form"
|
||||
pathTagKey = "path"
|
||||
defaultSummaryKey = "summary"
|
||||
)
|
||||
|
||||
var definedKeys = []string{bodyTagKey, formTagKey, "path"}
|
||||
var definedKeys = []string{bodyTagKey, formTagKey, pathTagKey}
|
||||
|
||||
// Routes returns all routes in api service
|
||||
func (s Service) Routes() []Route {
|
||||
@@ -25,7 +26,7 @@ func (s Service) Routes() []Route {
|
||||
return result
|
||||
}
|
||||
|
||||
// Tags retuens all tags in Member
|
||||
// Tags returns all tags in Member
|
||||
func (m Member) Tags() []*Tag {
|
||||
tags, err := Parse(m.Tag)
|
||||
if err != nil {
|
||||
@@ -141,7 +142,7 @@ func (t DefineStruct) GetFormMembers() []Member {
|
||||
return result
|
||||
}
|
||||
|
||||
// GetNonBodyMembers retruns all have no tag fields
|
||||
// GetNonBodyMembers returns all have no tag fields
|
||||
func (t DefineStruct) GetNonBodyMembers() []Member {
|
||||
var result []Member
|
||||
for _, member := range t.Members {
|
||||
@@ -162,16 +163,16 @@ func (r Route) JoinedDoc() string {
|
||||
return strings.TrimSpace(doc)
|
||||
}
|
||||
|
||||
// GetAnnotation returns the value by specified key
|
||||
// GetAnnotation returns the value by specified key from @server
|
||||
func (r Route) GetAnnotation(key string) string {
|
||||
if r.Annotation.Properties == nil {
|
||||
if r.AtServerAnnotation.Properties == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return r.Annotation.Properties[key]
|
||||
return r.AtServerAnnotation.Properties[key]
|
||||
}
|
||||
|
||||
// GetAnnotation returns the value by specified key
|
||||
// GetAnnotation returns the value by specified key from @server
|
||||
func (g Group) GetAnnotation(key string) string {
|
||||
if g.Annotation.Properties == nil {
|
||||
return ""
|
||||
|
||||
@@ -63,14 +63,14 @@ type (
|
||||
|
||||
// Route describes api route
|
||||
Route struct {
|
||||
Annotation Annotation
|
||||
Method string
|
||||
Path string
|
||||
RequestType Type
|
||||
ResponseType Type
|
||||
Docs Doc
|
||||
Handler string
|
||||
AtDoc AtDoc
|
||||
AtServerAnnotation Annotation
|
||||
Method string
|
||||
Path string
|
||||
RequestType Type
|
||||
ResponseType Type
|
||||
Docs Doc
|
||||
Handler string
|
||||
AtDoc AtDoc
|
||||
}
|
||||
|
||||
// Service describes api service
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user