mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-28 09:05:29 +08:00
Compare commits
189 Commits
tools/goct
...
v1.5.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
421e6617b1 | ||
|
|
0ee7a271d3 | ||
|
|
af022b9655 | ||
|
|
98d46261d9 | ||
|
|
4222fd97bc | ||
|
|
814852f0b8 | ||
|
|
ded2888759 | ||
|
|
18d66a795d | ||
|
|
4211672bfd | ||
|
|
68df0c3620 | ||
|
|
5e435b6a76 | ||
|
|
0dcede6457 | ||
|
|
cc21f5fae2 | ||
|
|
b22ad50d59 | ||
|
|
974252980c | ||
|
|
8d83986d27 | ||
|
|
6821b0a7dd | ||
|
|
1ba1724c65 | ||
|
|
ca5a7df5b0 | ||
|
|
69a3024853 | ||
|
|
fd3abf3717 | ||
|
|
99b3750d10 | ||
|
|
33f6d7ebb8 | ||
|
|
c4ef9ceb68 | ||
|
|
e95861f28a | ||
|
|
d3cd7b17c0 | ||
|
|
a50515496c | ||
|
|
0423313d9b | ||
|
|
7bbe7de05f | ||
|
|
83a451f2f4 | ||
|
|
d2a874f21d | ||
|
|
fd85b24b25 | ||
|
|
14fcbd7658 | ||
|
|
cb3ffc76a3 | ||
|
|
45fbd7dc35 | ||
|
|
af821cf794 | ||
|
|
ec69950153 | ||
|
|
ce5e78db53 | ||
|
|
ed75802eaa | ||
|
|
76c92b571d | ||
|
|
a2e703c53e | ||
|
|
ca698deb2a | ||
|
|
a9f4aab86b | ||
|
|
c3f57e9b0a | ||
|
|
ad4cce959d | ||
|
|
279123f4a7 | ||
|
|
457eb1961b | ||
|
|
63df384a4b | ||
|
|
42bfa26e2b | ||
|
|
ff04356704 | ||
|
|
05db706c62 | ||
|
|
ef2e0d859d | ||
|
|
05ec16ae9d | ||
|
|
13e685e0db | ||
|
|
c10f44b74e | ||
|
|
57644420ed | ||
|
|
b245159417 | ||
|
|
c26ea17669 | ||
|
|
a7daff3587 | ||
|
|
6719d06146 | ||
|
|
0c6eaeda9f | ||
|
|
b9c0c0f8b5 | ||
|
|
77da459165 | ||
|
|
13cdbdc98b | ||
|
|
e8c1e6e09b | ||
|
|
f1171e01f2 | ||
|
|
61e562d0c7 | ||
|
|
b71453985c | ||
|
|
31b9ba19a2 | ||
|
|
3170afd57b | ||
|
|
03e365a5d8 | ||
|
|
7d4fce9588 | ||
|
|
916cea858f | ||
|
|
a86942d532 | ||
|
|
f76c70ea9a | ||
|
|
4cbfdb3d74 | ||
|
|
aefa6dfb50 | ||
|
|
9047029475 | ||
|
|
f296c182f7 | ||
|
|
40e7a4cd07 | ||
|
|
92e5819e91 | ||
|
|
8d23ab158b | ||
|
|
bcccfab824 | ||
|
|
f7e701a634 | ||
|
|
7c2d8e5cc2 | ||
|
|
5b622d6265 | ||
|
|
c5510a4e1b | ||
|
|
2a33b74b35 | ||
|
|
45bb547a81 | ||
|
|
f5f5261556 | ||
|
|
b176d5d434 | ||
|
|
92f6c48349 | ||
|
|
71e8230e65 | ||
|
|
018fa8e0a0 | ||
|
|
979fe9718a | ||
|
|
f998803131 | ||
|
|
1262266ac2 | ||
|
|
9c32bf8478 | ||
|
|
37ec7f6443 | ||
|
|
2fdc4dfc0f | ||
|
|
4b2a6ba3de | ||
|
|
7fa3f10f22 | ||
|
|
4a29a0b642 | ||
|
|
a62745a152 | ||
|
|
28314326e7 | ||
|
|
f6bdb6e1de | ||
|
|
efa6940001 | ||
|
|
da81d8f774 | ||
|
|
fd84b27bdc | ||
|
|
6b4d0d89c0 | ||
|
|
d61a55f779 | ||
|
|
8ef4164209 | ||
|
|
50e29e2075 | ||
|
|
452c9dbcaf | ||
|
|
3564e36a35 | ||
|
|
e479e47634 | ||
|
|
ad921a6419 | ||
|
|
44c8d6f269 | ||
|
|
8a4cc4f98d | ||
|
|
e751736516 | ||
|
|
032f2419a2 | ||
|
|
84adc054bc | ||
|
|
b92e706ce1 | ||
|
|
1b5946346e | ||
|
|
28d3905731 | ||
|
|
3726851c7f | ||
|
|
2f2ddd373b | ||
|
|
8d48e34eed | ||
|
|
32f78668db | ||
|
|
cd0f3726ed | ||
|
|
0217044900 | ||
|
|
8b4382dcec | ||
|
|
fa33329a44 | ||
|
|
d76a39ac26 | ||
|
|
76a7a17e57 | ||
|
|
4a2a8d9e45 | ||
|
|
ef26b39b4c | ||
|
|
3ca40001b4 | ||
|
|
278ae3d26a | ||
|
|
fa1d6d50a8 | ||
|
|
0f4973be06 | ||
|
|
a9aac7e420 | ||
|
|
925cf8d3d1 | ||
|
|
99ce24e2ab | ||
|
|
701bb31ed2 | ||
|
|
55e2c7ee83 | ||
|
|
90839965fa | ||
|
|
f7228e9af1 | ||
|
|
f95adae3c1 | ||
|
|
bff5b81ad9 | ||
|
|
f0bdfb928f | ||
|
|
e4a1b7bb39 | ||
|
|
b6906b5d21 | ||
|
|
116da96178 | ||
|
|
9fa98c2bd3 | ||
|
|
b1c4c4736f | ||
|
|
ef410e8083 | ||
|
|
c22bc1c8ea | ||
|
|
1853428011 | ||
|
|
3637e10815 | ||
|
|
93124329ac | ||
|
|
851a72f1cc | ||
|
|
a93c24ce84 | ||
|
|
9f42eda9ff | ||
|
|
8762a3b7ba | ||
|
|
2684a157ff | ||
|
|
63368d8b0c | ||
|
|
4f13fe8188 | ||
|
|
9fc7874336 | ||
|
|
e6518521eb | ||
|
|
8f5a0a2de7 | ||
|
|
774e8d1d08 | ||
|
|
8ad0668612 | ||
|
|
8a043d2443 | ||
|
|
0e2ee97a02 | ||
|
|
42300a7d83 | ||
|
|
fe97fab274 | ||
|
|
f93e752f98 | ||
|
|
3a66fc038f | ||
|
|
b028ed058d | ||
|
|
1fd0c3992b | ||
|
|
1aebb3e5e4 | ||
|
|
8ffe4c01d1 | ||
|
|
a31256b327 | ||
|
|
14caf5c799 | ||
|
|
c0f8a58ed7 | ||
|
|
3189ec7be6 | ||
|
|
f51e9f0ea7 | ||
|
|
ba9d510cdb |
@@ -6,3 +6,4 @@ ignore:
|
|||||||
- "tools"
|
- "tools"
|
||||||
- "**/mock"
|
- "**/mock"
|
||||||
- "**/*_mock.go"
|
- "**/*_mock.go"
|
||||||
|
- "**/*test"
|
||||||
|
|||||||
2
.github/workflows/go.yml
vendored
2
.github/workflows/go.yml
vendored
@@ -61,5 +61,5 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
go mod verify
|
go mod verify
|
||||||
go mod download
|
go mod download
|
||||||
go test -v -race ./...
|
go test ./...
|
||||||
cd tools/goctl && go build -v goctl.go
|
cd tools/goctl && go build -v goctl.go
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -12,11 +12,13 @@
|
|||||||
|
|
||||||
# ignore
|
# ignore
|
||||||
**/.idea
|
**/.idea
|
||||||
|
**/.vscode
|
||||||
**/.DS_Store
|
**/.DS_Store
|
||||||
**/logs
|
**/logs
|
||||||
|
**/adhoc
|
||||||
|
**/coverage.txt
|
||||||
|
|
||||||
# for test purpose
|
# for test purpose
|
||||||
**/adhoc
|
|
||||||
go.work
|
go.work
|
||||||
go.work.sum
|
go.work.sum
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ func NewSafeMap() *SafeMap {
|
|||||||
// Del deletes the value with the given key from m.
|
// Del deletes the value with the given key from m.
|
||||||
func (m *SafeMap) Del(key any) {
|
func (m *SafeMap) Del(key any) {
|
||||||
m.lock.Lock()
|
m.lock.Lock()
|
||||||
|
defer m.lock.Unlock()
|
||||||
|
|
||||||
if _, ok := m.dirtyOld[key]; ok {
|
if _, ok := m.dirtyOld[key]; ok {
|
||||||
delete(m.dirtyOld, key)
|
delete(m.dirtyOld, key)
|
||||||
m.deletionOld++
|
m.deletionOld++
|
||||||
@@ -52,7 +54,6 @@ func (m *SafeMap) Del(key any) {
|
|||||||
m.dirtyNew = make(map[any]any)
|
m.dirtyNew = make(map[any]any)
|
||||||
m.deletionNew = 0
|
m.deletionNew = 0
|
||||||
}
|
}
|
||||||
m.lock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get gets the value with the given key from m.
|
// Get gets the value with the given key from m.
|
||||||
@@ -89,6 +90,8 @@ func (m *SafeMap) Range(f func(key, val any) bool) {
|
|||||||
// Set sets the value into m with the given key.
|
// Set sets the value into m with the given key.
|
||||||
func (m *SafeMap) Set(key, value any) {
|
func (m *SafeMap) Set(key, value any) {
|
||||||
m.lock.Lock()
|
m.lock.Lock()
|
||||||
|
defer m.lock.Unlock()
|
||||||
|
|
||||||
if m.deletionOld <= maxDeletion {
|
if m.deletionOld <= maxDeletion {
|
||||||
if _, ok := m.dirtyNew[key]; ok {
|
if _, ok := m.dirtyNew[key]; ok {
|
||||||
delete(m.dirtyNew, key)
|
delete(m.dirtyNew, key)
|
||||||
@@ -102,7 +105,6 @@ func (m *SafeMap) Set(key, value any) {
|
|||||||
}
|
}
|
||||||
m.dirtyNew[key] = value
|
m.dirtyNew[key] = value
|
||||||
}
|
}
|
||||||
m.lock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size returns the size of m.
|
// Size returns the size of m.
|
||||||
|
|||||||
@@ -147,3 +147,65 @@ func TestSafeMap_Range(t *testing.T) {
|
|||||||
assert.Equal(t, m.dirtyNew, newMap.dirtyNew)
|
assert.Equal(t, m.dirtyNew, newMap.dirtyNew)
|
||||||
assert.Equal(t, m.dirtyOld, newMap.dirtyOld)
|
assert.Equal(t, m.dirtyOld, newMap.dirtyOld)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetManyTimes(t *testing.T) {
|
||||||
|
const iteration = maxDeletion * 2
|
||||||
|
m := NewSafeMap()
|
||||||
|
for i := 0; i < iteration; i++ {
|
||||||
|
m.Set(i, i)
|
||||||
|
if i%3 == 0 {
|
||||||
|
m.Del(i / 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var count int
|
||||||
|
m.Range(func(k, v any) bool {
|
||||||
|
count++
|
||||||
|
return count < maxDeletion/2
|
||||||
|
})
|
||||||
|
assert.Equal(t, maxDeletion/2, count)
|
||||||
|
for i := 0; i < iteration; i++ {
|
||||||
|
m.Set(i, i)
|
||||||
|
if i%3 == 0 {
|
||||||
|
m.Del(i / 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := 0; i < iteration; i++ {
|
||||||
|
m.Set(i, i)
|
||||||
|
if i%3 == 0 {
|
||||||
|
m.Del(i / 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := 0; i < iteration; i++ {
|
||||||
|
m.Set(i, i)
|
||||||
|
if i%3 == 0 {
|
||||||
|
m.Del(i / 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
m.Range(func(k, v any) bool {
|
||||||
|
count++
|
||||||
|
return count < maxDeletion
|
||||||
|
})
|
||||||
|
assert.Equal(t, maxDeletion, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetManyTimesNew(t *testing.T) {
|
||||||
|
m := NewSafeMap()
|
||||||
|
for i := 0; i < maxDeletion*3; i++ {
|
||||||
|
m.Set(i, i)
|
||||||
|
}
|
||||||
|
for i := 0; i < maxDeletion*2; i++ {
|
||||||
|
m.Del(i)
|
||||||
|
}
|
||||||
|
for i := 0; i < maxDeletion*3; i++ {
|
||||||
|
m.Set(i+maxDeletion*3, i+maxDeletion*3)
|
||||||
|
}
|
||||||
|
for i := 0; i < maxDeletion*2; i++ {
|
||||||
|
m.Del(i + maxDeletion*2)
|
||||||
|
}
|
||||||
|
for i := 0; i < maxDeletion-copyThreshold+1; i++ {
|
||||||
|
m.Del(i + maxDeletion*2)
|
||||||
|
}
|
||||||
|
assert.Equal(t, 0, len(m.dirtyNew))
|
||||||
|
}
|
||||||
|
|||||||
@@ -35,11 +35,11 @@ func TestConfigJson(t *testing.T) {
|
|||||||
"c": "${FOO}",
|
"c": "${FOO}",
|
||||||
"d": "abcd!@#$112"
|
"d": "abcd!@#$112"
|
||||||
}`
|
}`
|
||||||
|
t.Setenv("FOO", "2")
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
test := test
|
test := test
|
||||||
t.Run(test, func(t *testing.T) {
|
t.Run(test, func(t *testing.T) {
|
||||||
os.Setenv("FOO", "2")
|
|
||||||
defer os.Unsetenv("FOO")
|
|
||||||
tmpfile, err := createTempFile(test, text)
|
tmpfile, err := createTempFile(test, text)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(tmpfile)
|
defer os.Remove(tmpfile)
|
||||||
@@ -81,8 +81,7 @@ b = 1
|
|||||||
c = "${FOO}"
|
c = "${FOO}"
|
||||||
d = "abcd!@#$112"
|
d = "abcd!@#$112"
|
||||||
`
|
`
|
||||||
os.Setenv("FOO", "2")
|
t.Setenv("FOO", "2")
|
||||||
defer os.Unsetenv("FOO")
|
|
||||||
tmpfile, err := createTempFile(".toml", text)
|
tmpfile, err := createTempFile(".toml", text)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(tmpfile)
|
defer os.Remove(tmpfile)
|
||||||
@@ -207,8 +206,7 @@ b = 1
|
|||||||
c = "${FOO}"
|
c = "${FOO}"
|
||||||
d = "abcd!@#112"
|
d = "abcd!@#112"
|
||||||
`
|
`
|
||||||
os.Setenv("FOO", "2")
|
t.Setenv("FOO", "2")
|
||||||
defer os.Unsetenv("FOO")
|
|
||||||
tmpfile, err := createTempFile(".toml", text)
|
tmpfile, err := createTempFile(".toml", text)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(tmpfile)
|
defer os.Remove(tmpfile)
|
||||||
@@ -239,11 +237,10 @@ func TestConfigJsonEnv(t *testing.T) {
|
|||||||
"c": "${FOO}",
|
"c": "${FOO}",
|
||||||
"d": "abcd!@#$a12 3"
|
"d": "abcd!@#$a12 3"
|
||||||
}`
|
}`
|
||||||
|
t.Setenv("FOO", "2")
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
test := test
|
test := test
|
||||||
t.Run(test, func(t *testing.T) {
|
t.Run(test, func(t *testing.T) {
|
||||||
os.Setenv("FOO", "2")
|
|
||||||
defer os.Unsetenv("FOO")
|
|
||||||
tmpfile, err := createTempFile(test, text)
|
tmpfile, err := createTempFile(test, text)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(tmpfile)
|
defer os.Remove(tmpfile)
|
||||||
|
|||||||
@@ -45,8 +45,7 @@ func TestPropertiesEnv(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(tmpfile)
|
defer os.Remove(tmpfile)
|
||||||
|
|
||||||
os.Setenv("FOO", "2")
|
t.Setenv("FOO", "2")
|
||||||
defer os.Unsetenv("FOO")
|
|
||||||
|
|
||||||
props, err := LoadProperties(tmpfile, UseEnv())
|
props, err := LoadProperties(tmpfile, UseEnv())
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|||||||
@@ -337,13 +337,11 @@ func (c *cluster) watchConnState(cli EtcdClient) {
|
|||||||
// DialClient dials an etcd cluster with given endpoints.
|
// DialClient dials an etcd cluster with given endpoints.
|
||||||
func DialClient(endpoints []string) (EtcdClient, error) {
|
func DialClient(endpoints []string) (EtcdClient, error) {
|
||||||
cfg := clientv3.Config{
|
cfg := clientv3.Config{
|
||||||
Endpoints: endpoints,
|
Endpoints: endpoints,
|
||||||
AutoSyncInterval: autoSyncInterval,
|
AutoSyncInterval: autoSyncInterval,
|
||||||
DialTimeout: DialTimeout,
|
DialTimeout: DialTimeout,
|
||||||
DialKeepAliveTime: dialKeepAliveTime,
|
RejectOldCluster: true,
|
||||||
DialKeepAliveTimeout: DialTimeout,
|
PermitWithoutStream: true,
|
||||||
RejectOldCluster: true,
|
|
||||||
PermitWithoutStream: true,
|
|
||||||
}
|
}
|
||||||
if account, ok := GetAccount(endpoints); ok {
|
if account, ok := GetAccount(endpoints); ok {
|
||||||
cfg.Username = account.User
|
cfg.Username = account.User
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ const (
|
|||||||
autoSyncInterval = time.Minute
|
autoSyncInterval = time.Minute
|
||||||
coolDownInterval = time.Second
|
coolDownInterval = time.Second
|
||||||
dialTimeout = 5 * time.Second
|
dialTimeout = 5 * time.Second
|
||||||
dialKeepAliveTime = 5 * time.Second
|
|
||||||
requestTimeout = 3 * time.Second
|
requestTimeout = 3 * time.Second
|
||||||
endpointsSeparator = ","
|
endpointsSeparator = ","
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ func TestBulkExecutorFlush(t *testing.T) {
|
|||||||
wait.Wait()
|
wait.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuldExecutorFlushSlowTasks(t *testing.T) {
|
func TestBulkExecutorFlushSlowTasks(t *testing.T) {
|
||||||
const total = 1500
|
const total = 1500
|
||||||
lock := new(sync.Mutex)
|
lock := new(sync.Mutex)
|
||||||
result := make([]any, 0, 10000)
|
result := make([]any, 0, 10000)
|
||||||
|
|||||||
@@ -168,23 +168,23 @@ func TestPeriodicalExecutor_FlushPanic(t *testing.T) {
|
|||||||
|
|
||||||
func TestPeriodicalExecutor_Wait(t *testing.T) {
|
func TestPeriodicalExecutor_Wait(t *testing.T) {
|
||||||
var lock sync.Mutex
|
var lock sync.Mutex
|
||||||
executer := NewBulkExecutor(func(tasks []any) {
|
executor := NewBulkExecutor(func(tasks []any) {
|
||||||
lock.Lock()
|
lock.Lock()
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}, WithBulkTasks(1), WithBulkInterval(time.Second))
|
}, WithBulkTasks(1), WithBulkInterval(time.Second))
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
executer.Add(1)
|
executor.Add(1)
|
||||||
}
|
}
|
||||||
executer.Flush()
|
executor.Flush()
|
||||||
executer.Wait()
|
executor.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPeriodicalExecutor_WaitFast(t *testing.T) {
|
func TestPeriodicalExecutor_WaitFast(t *testing.T) {
|
||||||
const total = 3
|
const total = 3
|
||||||
var cnt int
|
var cnt int
|
||||||
var lock sync.Mutex
|
var lock sync.Mutex
|
||||||
executer := NewBulkExecutor(func(tasks []any) {
|
executor := NewBulkExecutor(func(tasks []any) {
|
||||||
defer func() {
|
defer func() {
|
||||||
cnt++
|
cnt++
|
||||||
}()
|
}()
|
||||||
@@ -193,10 +193,10 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) {
|
|||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}, WithBulkTasks(1), WithBulkInterval(10*time.Millisecond))
|
}, WithBulkTasks(1), WithBulkInterval(10*time.Millisecond))
|
||||||
for i := 0; i < total; i++ {
|
for i := 0; i < total; i++ {
|
||||||
executer.Add(2)
|
executor.Add(2)
|
||||||
}
|
}
|
||||||
executer.Flush()
|
executor.Flush()
|
||||||
executer.Wait()
|
executor.Wait()
|
||||||
assert.Equal(t, total, cnt)
|
assert.Equal(t, total, cnt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -74,6 +74,11 @@ func TestFirstLineShort(t *testing.T) {
|
|||||||
assert.Equal(t, "first line", val)
|
assert.Equal(t, "first line", val)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFirstLineError(t *testing.T) {
|
||||||
|
_, err := FirstLine("/tmp/does-not-exist")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestLastLine(t *testing.T) {
|
func TestLastLine(t *testing.T) {
|
||||||
filename, err := fs.TempFilenameWithText(text)
|
filename, err := fs.TempFilenameWithText(text)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
@@ -113,3 +118,8 @@ func TestLastLineWithLastNewlineShort(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "last line", val)
|
assert.Equal(t, "last line", val)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLastLineError(t *testing.T) {
|
||||||
|
_, err := LastLine("/tmp/does-not-exist")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,29 +11,29 @@ import (
|
|||||||
// The file is kept as open, the caller should close the file handle,
|
// The file is kept as open, the caller should close the file handle,
|
||||||
// and remove the file by name.
|
// and remove the file by name.
|
||||||
func TempFileWithText(text string) (*os.File, error) {
|
func TempFileWithText(text string) (*os.File, error) {
|
||||||
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
|
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
|
if err := os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return tmpfile, nil
|
return tmpFile, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TempFilenameWithText creates the file with the given content,
|
// TempFilenameWithText creates the file with the given content,
|
||||||
// and returns the filename (full path).
|
// and returns the filename (full path).
|
||||||
// The caller should remove the file after use.
|
// The caller should remove the file after use.
|
||||||
func TempFilenameWithText(text string) (string, error) {
|
func TempFilenameWithText(text string) (string, error) {
|
||||||
tmpfile, err := TempFileWithText(text)
|
tmpFile, err := TempFileWithText(text)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
filename := tmpfile.Name()
|
filename := tmpFile.Name()
|
||||||
if err = tmpfile.Close(); err != nil {
|
if err = tmpFile.Close(); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,31 +1,87 @@
|
|||||||
package fx
|
package fx
|
||||||
|
|
||||||
import "github.com/zeromicro/go-zero/core/errorx"
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/errorx"
|
||||||
|
)
|
||||||
|
|
||||||
const defaultRetryTimes = 3
|
const defaultRetryTimes = 3
|
||||||
|
|
||||||
|
var errTimeout = errors.New("retry timeout")
|
||||||
|
|
||||||
type (
|
type (
|
||||||
// RetryOption defines the method to customize DoWithRetry.
|
// RetryOption defines the method to customize DoWithRetry.
|
||||||
RetryOption func(*retryOptions)
|
RetryOption func(*retryOptions)
|
||||||
|
|
||||||
retryOptions struct {
|
retryOptions struct {
|
||||||
times int
|
times int
|
||||||
|
interval time.Duration
|
||||||
|
timeout time.Duration
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
|
// DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
|
||||||
|
// Note that if the fn function accesses global variables outside the function
|
||||||
|
// and performs modification operations, it is best to lock them,
|
||||||
|
// otherwise there may be data race issues
|
||||||
func DoWithRetry(fn func() error, opts ...RetryOption) error {
|
func DoWithRetry(fn func() error, opts ...RetryOption) error {
|
||||||
|
return retry(func(errChan chan error, retryCount int) {
|
||||||
|
errChan <- fn()
|
||||||
|
}, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoWithRetryCtx runs fn, and retries if failed. Default to retry 3 times.
|
||||||
|
// fn retryCount indicates the current number of retries, starting from 0
|
||||||
|
// Note that if the fn function accesses global variables outside the function
|
||||||
|
// and performs modification operations, it is best to lock them,
|
||||||
|
// otherwise there may be data race issues
|
||||||
|
func DoWithRetryCtx(ctx context.Context, fn func(ctx context.Context, retryCount int) error,
|
||||||
|
opts ...RetryOption) error {
|
||||||
|
return retry(func(errChan chan error, retryCount int) {
|
||||||
|
errChan <- fn(ctx, retryCount)
|
||||||
|
}, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func retry(fn func(errChan chan error, retryCount int), opts ...RetryOption) error {
|
||||||
options := newRetryOptions()
|
options := newRetryOptions()
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(options)
|
opt(options)
|
||||||
}
|
}
|
||||||
|
|
||||||
var berr errorx.BatchError
|
var berr errorx.BatchError
|
||||||
|
var cancelFunc context.CancelFunc
|
||||||
|
ctx := context.Background()
|
||||||
|
if options.timeout > 0 {
|
||||||
|
ctx, cancelFunc = context.WithTimeout(ctx, options.timeout)
|
||||||
|
defer cancelFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
for i := 0; i < options.times; i++ {
|
for i := 0; i < options.times; i++ {
|
||||||
if err := fn(); err != nil {
|
go fn(errChan, i)
|
||||||
berr.Add(err)
|
|
||||||
} else {
|
select {
|
||||||
return nil
|
case err := <-errChan:
|
||||||
|
if err != nil {
|
||||||
|
berr.Add(err)
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
berr.Add(errTimeout)
|
||||||
|
return berr.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.interval > 0 {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
berr.Add(errTimeout)
|
||||||
|
return berr.Err()
|
||||||
|
case <-time.After(options.interval):
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,6 +95,18 @@ func WithRetry(times int) RetryOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithInterval(interval time.Duration) RetryOption {
|
||||||
|
return func(options *retryOptions) {
|
||||||
|
options.interval = interval
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithTimeout(timeout time.Duration) RetryOption {
|
||||||
|
return func(options *retryOptions) {
|
||||||
|
options.timeout = timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newRetryOptions() *retryOptions {
|
func newRetryOptions() *retryOptions {
|
||||||
return &retryOptions{
|
return &retryOptions{
|
||||||
times: defaultRetryTimes,
|
times: defaultRetryTimes,
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package fx
|
package fx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
@@ -12,31 +14,103 @@ func TestRetry(t *testing.T) {
|
|||||||
return errors.New("any")
|
return errors.New("any")
|
||||||
}))
|
}))
|
||||||
|
|
||||||
var times int
|
times1 := 0
|
||||||
assert.Nil(t, DoWithRetry(func() error {
|
assert.Nil(t, DoWithRetry(func() error {
|
||||||
times++
|
times1++
|
||||||
if times == defaultRetryTimes {
|
if times1 == defaultRetryTimes {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return errors.New("any")
|
return errors.New("any")
|
||||||
}))
|
}))
|
||||||
|
|
||||||
times = 0
|
times2 := 0
|
||||||
assert.NotNil(t, DoWithRetry(func() error {
|
assert.NotNil(t, DoWithRetry(func() error {
|
||||||
times++
|
times2++
|
||||||
if times == defaultRetryTimes+1 {
|
if times2 == defaultRetryTimes+1 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return errors.New("any")
|
return errors.New("any")
|
||||||
}))
|
}))
|
||||||
|
|
||||||
total := 2 * defaultRetryTimes
|
total := 2 * defaultRetryTimes
|
||||||
times = 0
|
times3 := 0
|
||||||
assert.Nil(t, DoWithRetry(func() error {
|
assert.Nil(t, DoWithRetry(func() error {
|
||||||
times++
|
times3++
|
||||||
if times == total {
|
if times3 == total {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return errors.New("any")
|
return errors.New("any")
|
||||||
}, WithRetry(total)))
|
}, WithRetry(total)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRetryWithTimeout(t *testing.T) {
|
||||||
|
assert.Nil(t, DoWithRetry(func() error {
|
||||||
|
return nil
|
||||||
|
}, WithTimeout(time.Millisecond*500)))
|
||||||
|
|
||||||
|
times1 := 0
|
||||||
|
assert.Nil(t, DoWithRetry(func() error {
|
||||||
|
times1++
|
||||||
|
if times1 == 1 {
|
||||||
|
return errors.New("any ")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond * 150)
|
||||||
|
return nil
|
||||||
|
}, WithTimeout(time.Millisecond*250)))
|
||||||
|
|
||||||
|
total := defaultRetryTimes
|
||||||
|
times2 := 0
|
||||||
|
assert.Nil(t, DoWithRetry(func() error {
|
||||||
|
times2++
|
||||||
|
if times2 == total {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
return errors.New("any")
|
||||||
|
}, WithTimeout(time.Millisecond*50*(time.Duration(total)+2))))
|
||||||
|
|
||||||
|
assert.NotNil(t, DoWithRetry(func() error {
|
||||||
|
return errors.New("any")
|
||||||
|
}, WithTimeout(time.Millisecond*250)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryWithInterval(t *testing.T) {
|
||||||
|
times1 := 0
|
||||||
|
assert.NotNil(t, DoWithRetry(func() error {
|
||||||
|
times1++
|
||||||
|
if times1 == 1 {
|
||||||
|
return errors.New("any")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond * 150)
|
||||||
|
return nil
|
||||||
|
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
|
||||||
|
|
||||||
|
times2 := 0
|
||||||
|
assert.NotNil(t, DoWithRetry(func() error {
|
||||||
|
times2++
|
||||||
|
if times2 == 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond * 150)
|
||||||
|
return errors.New("any ")
|
||||||
|
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryCtx(t *testing.T) {
|
||||||
|
assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
|
||||||
|
if retryCount == 0 {
|
||||||
|
return errors.New("any")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond * 150)
|
||||||
|
return nil
|
||||||
|
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
|
||||||
|
|
||||||
|
assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
|
||||||
|
if retryCount == 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond * 150)
|
||||||
|
return errors.New("any ")
|
||||||
|
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,6 +32,10 @@ func (bp *BufferPool) Get() *bytes.Buffer {
|
|||||||
|
|
||||||
// Put returns buf into bp.
|
// Put returns buf into bp.
|
||||||
func (bp *BufferPool) Put(buf *bytes.Buffer) {
|
func (bp *BufferPool) Put(buf *bytes.Buffer) {
|
||||||
|
if buf == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if buf.Cap() < bp.capability {
|
if buf.Cap() < bp.capability {
|
||||||
bp.pool.Put(buf)
|
bp.pool.Put(buf)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,3 +13,26 @@ func TestBufferPool(t *testing.T) {
|
|||||||
pool.Put(bytes.NewBuffer(make([]byte, 0, 2*capacity)))
|
pool.Put(bytes.NewBuffer(make([]byte, 0, 2*capacity)))
|
||||||
assert.True(t, pool.Get().Cap() <= capacity)
|
assert.True(t, pool.Get().Cap() <= capacity)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBufferPool_Put(t *testing.T) {
|
||||||
|
t.Run("with nil buf", func(t *testing.T) {
|
||||||
|
pool := NewBufferPool(1024)
|
||||||
|
pool.Put(nil)
|
||||||
|
val := pool.Get()
|
||||||
|
assert.IsType(t, new(bytes.Buffer), val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with less-cap buf", func(t *testing.T) {
|
||||||
|
pool := NewBufferPool(1024)
|
||||||
|
pool.Put(bytes.NewBuffer(make([]byte, 0, 512)))
|
||||||
|
val := pool.Get()
|
||||||
|
assert.IsType(t, new(bytes.Buffer), val)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with more-cap buf", func(t *testing.T) {
|
||||||
|
pool := NewBufferPool(1024)
|
||||||
|
pool.Put(bytes.NewBuffer(make([]byte, 0, 1024<<1)))
|
||||||
|
val := pool.Get()
|
||||||
|
assert.IsType(t, new(bytes.Buffer), val)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
12
core/iox/nopcloser_test.go
Normal file
12
core/iox/nopcloser_test.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package iox
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNopCloser(t *testing.T) {
|
||||||
|
closer := NopCloser(nil)
|
||||||
|
assert.NoError(t, closer.Close())
|
||||||
|
}
|
||||||
@@ -35,6 +35,16 @@ func KeepSpace() TextReadOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LimitDupReadCloser returns two io.ReadCloser that read from the first will be written to the second.
|
||||||
|
// But the second io.ReadCloser is limited to up to n bytes.
|
||||||
|
// The first returned reader needs to be read first, because the content
|
||||||
|
// read from it will be written to the underlying buffer of the second reader.
|
||||||
|
func LimitDupReadCloser(reader io.ReadCloser, n int64) (io.ReadCloser, io.ReadCloser) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
tee := LimitTeeReader(reader, &buf, n)
|
||||||
|
return io.NopCloser(tee), io.NopCloser(&buf)
|
||||||
|
}
|
||||||
|
|
||||||
// ReadBytes reads exactly the bytes with the length of len(buf)
|
// ReadBytes reads exactly the bytes with the length of len(buf)
|
||||||
func ReadBytes(reader io.Reader, buf []byte) error {
|
func ReadBytes(reader io.Reader, buf []byte) error {
|
||||||
var got int
|
var got int
|
||||||
|
|||||||
@@ -40,17 +40,22 @@ b`,
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.input, func(t *testing.T) {
|
t.Run(test.input, func(t *testing.T) {
|
||||||
tmpfile, err := fs.TempFilenameWithText(test.input)
|
tmpFile, err := fs.TempFilenameWithText(test.input)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(tmpfile)
|
defer os.Remove(tmpFile)
|
||||||
|
|
||||||
content, err := ReadText(tmpfile)
|
content, err := ReadText(tmpFile)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, test.expect, content)
|
assert.Equal(t, test.expect, content)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReadTextError(t *testing.T) {
|
||||||
|
_, err := ReadText("not-exist")
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestReadTextLines(t *testing.T) {
|
func TestReadTextLines(t *testing.T) {
|
||||||
text := `1
|
text := `1
|
||||||
|
|
||||||
@@ -59,9 +64,9 @@ func TestReadTextLines(t *testing.T) {
|
|||||||
#a
|
#a
|
||||||
3`
|
3`
|
||||||
|
|
||||||
tmpfile, err := fs.TempFilenameWithText(text)
|
tmpFile, err := fs.TempFilenameWithText(text)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(tmpfile)
|
defer os.Remove(tmpFile)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
options []TextReadOption
|
options []TextReadOption
|
||||||
@@ -87,13 +92,18 @@ func TestReadTextLines(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(stringx.Rand(), func(t *testing.T) {
|
t.Run(stringx.Rand(), func(t *testing.T) {
|
||||||
lines, err := ReadTextLines(tmpfile, test.options...)
|
lines, err := ReadTextLines(tmpFile, test.options...)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, test.expectLines, len(lines))
|
assert.Equal(t, test.expectLines, len(lines))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReadTextLinesError(t *testing.T) {
|
||||||
|
_, err := ReadTextLines("not-exist")
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDupReadCloser(t *testing.T) {
|
func TestDupReadCloser(t *testing.T) {
|
||||||
input := "hello"
|
input := "hello"
|
||||||
reader := io.NopCloser(bytes.NewBufferString(input))
|
reader := io.NopCloser(bytes.NewBufferString(input))
|
||||||
@@ -108,6 +118,29 @@ func TestDupReadCloser(t *testing.T) {
|
|||||||
verify(r2)
|
verify(r2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLimitDupReadCloser(t *testing.T) {
|
||||||
|
input := "hello world"
|
||||||
|
limitBytes := int64(4)
|
||||||
|
reader := io.NopCloser(bytes.NewBufferString(input))
|
||||||
|
r1, r2 := LimitDupReadCloser(reader, limitBytes)
|
||||||
|
verify := func(r io.Reader) {
|
||||||
|
output, err := io.ReadAll(r)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, input, string(output))
|
||||||
|
}
|
||||||
|
verifyLimit := func(r io.Reader, limit int64) {
|
||||||
|
output, err := io.ReadAll(r)
|
||||||
|
if limit < int64(len(input)) {
|
||||||
|
input = input[:limit]
|
||||||
|
}
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, input, string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
verify(r1)
|
||||||
|
verifyLimit(r2, limitBytes)
|
||||||
|
}
|
||||||
|
|
||||||
func TestReadBytes(t *testing.T) {
|
func TestReadBytes(t *testing.T) {
|
||||||
reader := io.NopCloser(bytes.NewBufferString("helloworld"))
|
reader := io.NopCloser(bytes.NewBufferString("helloworld"))
|
||||||
buf := make([]byte, 5)
|
buf := make([]byte, 5)
|
||||||
|
|||||||
35
core/iox/tee.go
Normal file
35
core/iox/tee.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package iox
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
// LimitTeeReader returns a Reader that writes up to n bytes to w what it reads from r.
|
||||||
|
// First n bytes reads from r performed through it are matched with
|
||||||
|
// corresponding writes to w. There is no internal buffering -
|
||||||
|
// the write must complete before the first n bytes read completes.
|
||||||
|
// Any error encountered while writing is reported as a read error.
|
||||||
|
func LimitTeeReader(r io.Reader, w io.Writer, n int64) io.Reader {
|
||||||
|
return &limitTeeReader{r, w, n}
|
||||||
|
}
|
||||||
|
|
||||||
|
type limitTeeReader struct {
|
||||||
|
r io.Reader
|
||||||
|
w io.Writer
|
||||||
|
n int64 // limit bytes remaining
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *limitTeeReader) Read(p []byte) (n int, err error) {
|
||||||
|
n, err = t.r.Read(p)
|
||||||
|
if n > 0 && t.n > 0 {
|
||||||
|
limit := int64(n)
|
||||||
|
if limit > t.n {
|
||||||
|
limit = t.n
|
||||||
|
}
|
||||||
|
if n, err := t.w.Write(p[:limit]); err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.n -= limit
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
40
core/iox/tee_test.go
Normal file
40
core/iox/tee_test.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package iox
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLimitTeeReader(t *testing.T) {
|
||||||
|
limit := int64(4)
|
||||||
|
src := []byte("hello, world")
|
||||||
|
dst := make([]byte, len(src))
|
||||||
|
rb := bytes.NewBuffer(src)
|
||||||
|
wb := new(bytes.Buffer)
|
||||||
|
r := LimitTeeReader(rb, wb, limit)
|
||||||
|
if n, err := io.ReadFull(r, dst); err != nil || n != len(src) {
|
||||||
|
t.Fatalf("ReadFull(r, dst) = %d, %v; want %d, nil", n, err, len(src))
|
||||||
|
}
|
||||||
|
if !bytes.Equal(dst, src) {
|
||||||
|
t.Errorf("bytes read = %q want %q", dst, src)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(wb.Bytes(), src[:limit]) {
|
||||||
|
t.Errorf("bytes written = %q want %q", wb.Bytes(), src)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := r.Read(dst)
|
||||||
|
assert.Equal(t, 0, n)
|
||||||
|
assert.Equal(t, io.EOF, err)
|
||||||
|
|
||||||
|
rb = bytes.NewBuffer(src)
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
if assert.NoError(t, pr.Close()) {
|
||||||
|
r = LimitTeeReader(rb, pw, limit)
|
||||||
|
n, err := io.ReadFull(r, dst)
|
||||||
|
assert.Equal(t, 0, n)
|
||||||
|
assert.Equal(t, io.ErrClosedPipe, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package iox
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
@@ -26,7 +27,7 @@ func CountLines(file string) (int, error) {
|
|||||||
count += bytes.Count(buf[:c], lineSep)
|
count += bytes.Count(buf[:c], lineSep)
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case err == io.EOF:
|
case errors.Is(err, io.EOF):
|
||||||
if noEol {
|
if noEol {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,3 +24,8 @@ func TestCountLines(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, 4, lines)
|
assert.Equal(t, 4, lines)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCountLinesError(t *testing.T) {
|
||||||
|
_, err := CountLines("not-exist")
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package iox
|
|||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"testing/iotest"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
@@ -22,3 +23,10 @@ func TestScanner(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.EqualValues(t, []string{"1", "2", "3", "4"}, lines)
|
assert.EqualValues(t, []string{"1", "2", "3", "4"}, lines)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBadScanner(t *testing.T) {
|
||||||
|
scanner := NewTextLineScanner(iotest.ErrReader(iotest.ErrTimeout))
|
||||||
|
assert.False(t, scanner.Scan())
|
||||||
|
_, err := scanner.Line()
|
||||||
|
assert.ErrorIs(t, err, iotest.ErrTimeout)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package logc
|
package logc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -11,14 +10,11 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAddGlobalFields(t *testing.T) {
|
func TestAddGlobalFields(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
Info(context.Background(), "hello")
|
Info(context.Background(), "hello")
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
@@ -34,155 +30,90 @@ func TestAddGlobalFields(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAlert(t *testing.T) {
|
func TestAlert(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
Alert(context.Background(), "foo")
|
Alert(context.Background(), "foo")
|
||||||
assert.True(t, strings.Contains(buf.String(), "foo"), buf.String())
|
assert.True(t, strings.Contains(buf.String(), "foo"), buf.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestError(t *testing.T) {
|
func TestError(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Error(context.Background(), "foo")
|
Error(context.Background(), "foo")
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestErrorf(t *testing.T) {
|
func TestErrorf(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Errorf(context.Background(), "foo %s", "bar")
|
Errorf(context.Background(), "foo %s", "bar")
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestErrorv(t *testing.T) {
|
func TestErrorv(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Errorv(context.Background(), "foo")
|
Errorv(context.Background(), "foo")
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestErrorw(t *testing.T) {
|
func TestErrorw(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Errorw(context.Background(), "foo", Field("a", "b"))
|
Errorw(context.Background(), "foo", Field("a", "b"))
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInfo(t *testing.T) {
|
func TestInfo(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Info(context.Background(), "foo")
|
Info(context.Background(), "foo")
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInfof(t *testing.T) {
|
func TestInfof(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Infof(context.Background(), "foo %s", "bar")
|
Infof(context.Background(), "foo %s", "bar")
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInfov(t *testing.T) {
|
func TestInfov(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Infov(context.Background(), "foo")
|
Infov(context.Background(), "foo")
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInfow(t *testing.T) {
|
func TestInfow(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Infow(context.Background(), "foo", Field("a", "b"))
|
Infow(context.Background(), "foo", Field("a", "b"))
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDebug(t *testing.T) {
|
func TestDebug(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Debug(context.Background(), "foo")
|
Debug(context.Background(), "foo")
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDebugf(t *testing.T) {
|
func TestDebugf(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Debugf(context.Background(), "foo %s", "bar")
|
Debugf(context.Background(), "foo %s", "bar")
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDebugv(t *testing.T) {
|
func TestDebugv(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Debugv(context.Background(), "foo")
|
Debugv(context.Background(), "foo")
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDebugw(t *testing.T) {
|
func TestDebugw(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Debugw(context.Background(), "foo", Field("a", "b"))
|
Debugw(context.Background(), "foo", Field("a", "b"))
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
@@ -204,48 +135,28 @@ func TestMisc(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSlow(t *testing.T) {
|
func TestSlow(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Slow(context.Background(), "foo")
|
Slow(context.Background(), "foo")
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSlowf(t *testing.T) {
|
func TestSlowf(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Slowf(context.Background(), "foo %s", "bar")
|
Slowf(context.Background(), "foo %s", "bar")
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSlowv(t *testing.T) {
|
func TestSlowv(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Slowv(context.Background(), "foo")
|
Slowv(context.Background(), "foo")
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSloww(t *testing.T) {
|
func TestSloww(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
file, line := getFileLine()
|
file, line := getFileLine()
|
||||||
Sloww(context.Background(), "foo", Field("a", "b"))
|
Sloww(context.Background(), "foo", Field("a", "b"))
|
||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ type LogConf struct {
|
|||||||
StackCooldownMillis int `json:",default=100"`
|
StackCooldownMillis int `json:",default=100"`
|
||||||
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
||||||
// Only take effect when RotationRuleType is `size`.
|
// Only take effect when RotationRuleType is `size`.
|
||||||
// Even thougth `MaxBackups` sets 0, log files will still be removed
|
// Even though `MaxBackups` sets 0, log files will still be removed
|
||||||
// if the `KeepDays` limitation is reached.
|
// if the `KeepDays` limitation is reached.
|
||||||
MaxBackups int `json:",default=0"`
|
MaxBackups int `json:",default=0"`
|
||||||
// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.
|
// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.
|
||||||
|
|||||||
40
core/logx/fs.go
Normal file
40
core/logx/fs.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package logx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
var fileSys realFileSystem
|
||||||
|
|
||||||
|
type (
|
||||||
|
fileSystem interface {
|
||||||
|
Close(closer io.Closer) error
|
||||||
|
Copy(writer io.Writer, reader io.Reader) (int64, error)
|
||||||
|
Create(name string) (*os.File, error)
|
||||||
|
Open(name string) (*os.File, error)
|
||||||
|
Remove(name string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
realFileSystem struct{}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (fs realFileSystem) Close(closer io.Closer) error {
|
||||||
|
return closer.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs realFileSystem) Copy(writer io.Writer, reader io.Reader) (int64, error) {
|
||||||
|
return io.Copy(writer, reader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs realFileSystem) Create(name string) (*os.File, error) {
|
||||||
|
return os.Create(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs realFileSystem) Open(name string) (*os.File, error) {
|
||||||
|
return os.Open(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs realFileSystem) Remove(name string) error {
|
||||||
|
return os.Remove(name)
|
||||||
|
}
|
||||||
@@ -68,22 +68,30 @@ func Close() error {
|
|||||||
|
|
||||||
// Debug writes v into access log.
|
// Debug writes v into access log.
|
||||||
func Debug(v ...any) {
|
func Debug(v ...any) {
|
||||||
writeDebug(fmt.Sprint(v...))
|
if shallLog(DebugLevel) {
|
||||||
|
writeDebug(fmt.Sprint(v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Debugf writes v with format into access log.
|
// Debugf writes v with format into access log.
|
||||||
func Debugf(format string, v ...any) {
|
func Debugf(format string, v ...any) {
|
||||||
writeDebug(fmt.Sprintf(format, v...))
|
if shallLog(DebugLevel) {
|
||||||
|
writeDebug(fmt.Sprintf(format, v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Debugv writes v into access log with json content.
|
// Debugv writes v into access log with json content.
|
||||||
func Debugv(v any) {
|
func Debugv(v any) {
|
||||||
writeDebug(v)
|
if shallLog(DebugLevel) {
|
||||||
|
writeDebug(v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Debugw writes msg along with fields into access log.
|
// Debugw writes msg along with fields into access log.
|
||||||
func Debugw(msg string, fields ...LogField) {
|
func Debugw(msg string, fields ...LogField) {
|
||||||
writeDebug(msg, fields...)
|
if shallLog(DebugLevel) {
|
||||||
|
writeDebug(msg, fields...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable disables the logging.
|
// Disable disables the logging.
|
||||||
@@ -99,35 +107,47 @@ func DisableStat() {
|
|||||||
|
|
||||||
// Error writes v into error log.
|
// Error writes v into error log.
|
||||||
func Error(v ...any) {
|
func Error(v ...any) {
|
||||||
writeError(fmt.Sprint(v...))
|
if shallLog(ErrorLevel) {
|
||||||
|
writeError(fmt.Sprint(v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Errorf writes v with format into error log.
|
// Errorf writes v with format into error log.
|
||||||
func Errorf(format string, v ...any) {
|
func Errorf(format string, v ...any) {
|
||||||
writeError(fmt.Errorf(format, v...).Error())
|
if shallLog(ErrorLevel) {
|
||||||
|
writeError(fmt.Errorf(format, v...).Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrorStack writes v along with call stack into error log.
|
// ErrorStack writes v along with call stack into error log.
|
||||||
func ErrorStack(v ...any) {
|
func ErrorStack(v ...any) {
|
||||||
// there is newline in stack string
|
if shallLog(ErrorLevel) {
|
||||||
writeStack(fmt.Sprint(v...))
|
// there is newline in stack string
|
||||||
|
writeStack(fmt.Sprint(v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrorStackf writes v along with call stack in format into error log.
|
// ErrorStackf writes v along with call stack in format into error log.
|
||||||
func ErrorStackf(format string, v ...any) {
|
func ErrorStackf(format string, v ...any) {
|
||||||
// there is newline in stack string
|
if shallLog(ErrorLevel) {
|
||||||
writeStack(fmt.Sprintf(format, v...))
|
// there is newline in stack string
|
||||||
|
writeStack(fmt.Sprintf(format, v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Errorv writes v into error log with json content.
|
// Errorv writes v into error log with json content.
|
||||||
// No call stack attached, because not elegant to pack the messages.
|
// No call stack attached, because not elegant to pack the messages.
|
||||||
func Errorv(v any) {
|
func Errorv(v any) {
|
||||||
writeError(v)
|
if shallLog(ErrorLevel) {
|
||||||
|
writeError(v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Errorw writes msg along with fields into error log.
|
// Errorw writes msg along with fields into error log.
|
||||||
func Errorw(msg string, fields ...LogField) {
|
func Errorw(msg string, fields ...LogField) {
|
||||||
writeError(msg, fields...)
|
if shallLog(ErrorLevel) {
|
||||||
|
writeError(msg, fields...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Field returns a LogField for the given key and value.
|
// Field returns a LogField for the given key and value.
|
||||||
@@ -170,22 +190,30 @@ func Field(key string, value any) LogField {
|
|||||||
|
|
||||||
// Info writes v into access log.
|
// Info writes v into access log.
|
||||||
func Info(v ...any) {
|
func Info(v ...any) {
|
||||||
writeInfo(fmt.Sprint(v...))
|
if shallLog(InfoLevel) {
|
||||||
|
writeInfo(fmt.Sprint(v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Infof writes v with format into access log.
|
// Infof writes v with format into access log.
|
||||||
func Infof(format string, v ...any) {
|
func Infof(format string, v ...any) {
|
||||||
writeInfo(fmt.Sprintf(format, v...))
|
if shallLog(InfoLevel) {
|
||||||
|
writeInfo(fmt.Sprintf(format, v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Infov writes v into access log with json content.
|
// Infov writes v into access log with json content.
|
||||||
func Infov(v any) {
|
func Infov(v any) {
|
||||||
writeInfo(v)
|
if shallLog(InfoLevel) {
|
||||||
|
writeInfo(v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Infow writes msg along with fields into access log.
|
// Infow writes msg along with fields into access log.
|
||||||
func Infow(msg string, fields ...LogField) {
|
func Infow(msg string, fields ...LogField) {
|
||||||
writeInfo(msg, fields...)
|
if shallLog(InfoLevel) {
|
||||||
|
writeInfo(msg, fields...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Must checks if err is nil, otherwise logs the error and exits.
|
// Must checks if err is nil, otherwise logs the error and exits.
|
||||||
@@ -194,7 +222,7 @@ func Must(err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := err.Error()
|
msg := fmt.Sprintf("%+v\n\n%s", err.Error(), debug.Stack())
|
||||||
log.Print(msg)
|
log.Print(msg)
|
||||||
getWriter().Severe(msg)
|
getWriter().Severe(msg)
|
||||||
|
|
||||||
@@ -269,42 +297,58 @@ func SetUp(c LogConf) (err error) {
|
|||||||
|
|
||||||
// Severe writes v into severe log.
|
// Severe writes v into severe log.
|
||||||
func Severe(v ...any) {
|
func Severe(v ...any) {
|
||||||
writeSevere(fmt.Sprint(v...))
|
if shallLog(SevereLevel) {
|
||||||
|
writeSevere(fmt.Sprint(v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Severef writes v with format into severe log.
|
// Severef writes v with format into severe log.
|
||||||
func Severef(format string, v ...any) {
|
func Severef(format string, v ...any) {
|
||||||
writeSevere(fmt.Sprintf(format, v...))
|
if shallLog(SevereLevel) {
|
||||||
|
writeSevere(fmt.Sprintf(format, v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Slow writes v into slow log.
|
// Slow writes v into slow log.
|
||||||
func Slow(v ...any) {
|
func Slow(v ...any) {
|
||||||
writeSlow(fmt.Sprint(v...))
|
if shallLog(ErrorLevel) {
|
||||||
|
writeSlow(fmt.Sprint(v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Slowf writes v with format into slow log.
|
// Slowf writes v with format into slow log.
|
||||||
func Slowf(format string, v ...any) {
|
func Slowf(format string, v ...any) {
|
||||||
writeSlow(fmt.Sprintf(format, v...))
|
if shallLog(ErrorLevel) {
|
||||||
|
writeSlow(fmt.Sprintf(format, v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Slowv writes v into slow log with json content.
|
// Slowv writes v into slow log with json content.
|
||||||
func Slowv(v any) {
|
func Slowv(v any) {
|
||||||
writeSlow(v)
|
if shallLog(ErrorLevel) {
|
||||||
|
writeSlow(v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sloww writes msg along with fields into slow log.
|
// Sloww writes msg along with fields into slow log.
|
||||||
func Sloww(msg string, fields ...LogField) {
|
func Sloww(msg string, fields ...LogField) {
|
||||||
writeSlow(msg, fields...)
|
if shallLog(ErrorLevel) {
|
||||||
|
writeSlow(msg, fields...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stat writes v into stat log.
|
// Stat writes v into stat log.
|
||||||
func Stat(v ...any) {
|
func Stat(v ...any) {
|
||||||
writeStat(fmt.Sprint(v...))
|
if shallLogStat() && shallLog(InfoLevel) {
|
||||||
|
writeStat(fmt.Sprint(v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Statf writes v with format into stat log.
|
// Statf writes v with format into stat log.
|
||||||
func Statf(format string, v ...any) {
|
func Statf(format string, v ...any) {
|
||||||
writeStat(fmt.Sprintf(format, v...))
|
if shallLogStat() && shallLog(InfoLevel) {
|
||||||
|
writeStat(fmt.Sprintf(format, v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithCooldownMillis customizes logging on writing call stack interval.
|
// WithCooldownMillis customizes logging on writing call stack interval.
|
||||||
@@ -358,14 +402,16 @@ func createOutput(path string) (io.WriteCloser, error) {
|
|||||||
return nil, ErrLogPathNotSet
|
return nil, ErrLogPathNotSet
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var rule RotateRule
|
||||||
switch options.rotationRule {
|
switch options.rotationRule {
|
||||||
case sizeRotationRule:
|
case sizeRotationRule:
|
||||||
return NewLogger(path, NewSizeLimitRotateRule(path, backupFileDelimiter, options.keepDays,
|
rule = NewSizeLimitRotateRule(path, backupFileDelimiter, options.keepDays, options.maxSize,
|
||||||
options.maxSize, options.maxBackups, options.gzipEnabled), options.gzipEnabled)
|
options.maxBackups, options.gzipEnabled)
|
||||||
default:
|
default:
|
||||||
return NewLogger(path, DefaultRotateRule(path, backupFileDelimiter, options.keepDays,
|
rule = DefaultRotateRule(path, backupFileDelimiter, options.keepDays, options.gzipEnabled)
|
||||||
options.gzipEnabled), options.gzipEnabled)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return NewLogger(path, rule, options.gzipEnabled)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getWriter() Writer {
|
func getWriter() Writer {
|
||||||
@@ -427,44 +473,58 @@ func shallLogStat() bool {
|
|||||||
return atomic.LoadUint32(&disableStat) == 0
|
return atomic.LoadUint32(&disableStat) == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeDebug writes v into debug log.
|
||||||
|
// Not checking shallLog here is for performance consideration.
|
||||||
|
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||||
|
// The caller should check shallLog before calling this function.
|
||||||
func writeDebug(val any, fields ...LogField) {
|
func writeDebug(val any, fields ...LogField) {
|
||||||
if shallLog(DebugLevel) {
|
getWriter().Debug(val, addCaller(fields...)...)
|
||||||
getWriter().Debug(val, addCaller(fields...)...)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeError writes v into error log.
|
||||||
|
// Not checking shallLog here is for performance consideration.
|
||||||
|
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||||
|
// The caller should check shallLog before calling this function.
|
||||||
func writeError(val any, fields ...LogField) {
|
func writeError(val any, fields ...LogField) {
|
||||||
if shallLog(ErrorLevel) {
|
getWriter().Error(val, addCaller(fields...)...)
|
||||||
getWriter().Error(val, addCaller(fields...)...)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeInfo writes v into info log.
|
||||||
|
// Not checking shallLog here is for performance consideration.
|
||||||
|
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||||
|
// The caller should check shallLog before calling this function.
|
||||||
func writeInfo(val any, fields ...LogField) {
|
func writeInfo(val any, fields ...LogField) {
|
||||||
if shallLog(InfoLevel) {
|
getWriter().Info(val, addCaller(fields...)...)
|
||||||
getWriter().Info(val, addCaller(fields...)...)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeSevere writes v into severe log.
|
||||||
|
// Not checking shallLog here is for performance consideration.
|
||||||
|
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||||
|
// The caller should check shallLog before calling this function.
|
||||||
func writeSevere(msg string) {
|
func writeSevere(msg string) {
|
||||||
if shallLog(SevereLevel) {
|
getWriter().Severe(fmt.Sprintf("%s\n%s", msg, string(debug.Stack())))
|
||||||
getWriter().Severe(fmt.Sprintf("%s\n%s", msg, string(debug.Stack())))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeSlow writes v into slow log.
|
||||||
|
// Not checking shallLog here is for performance consideration.
|
||||||
|
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||||
|
// The caller should check shallLog before calling this function.
|
||||||
func writeSlow(val any, fields ...LogField) {
|
func writeSlow(val any, fields ...LogField) {
|
||||||
if shallLog(ErrorLevel) {
|
getWriter().Slow(val, addCaller(fields...)...)
|
||||||
getWriter().Slow(val, addCaller(fields...)...)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeStack writes v into stack log.
|
||||||
|
// Not checking shallLog here is for performance consideration.
|
||||||
|
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||||
|
// The caller should check shallLog before calling this function.
|
||||||
func writeStack(msg string) {
|
func writeStack(msg string) {
|
||||||
if shallLog(ErrorLevel) {
|
getWriter().Stack(fmt.Sprintf("%s\n%s", msg, string(debug.Stack())))
|
||||||
getWriter().Stack(fmt.Sprintf("%s\n%s", msg, string(debug.Stack())))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeStat writes v into stat log.
|
||||||
|
// Not checking shallLog here is for performance consideration.
|
||||||
|
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||||
|
// The caller should check shallLog before calling this function.
|
||||||
func writeStat(msg string) {
|
func writeStat(msg string) {
|
||||||
if shallLogStat() && shallLog(InfoLevel) {
|
getWriter().Stat(msg, addCaller()...)
|
||||||
getWriter().Stat(msg, addCaller()...)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
84
core/logx/logtest/logtest.go
Normal file
84
core/logx/logtest/logtest.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package logtest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Buffer struct {
|
||||||
|
buf *bytes.Buffer
|
||||||
|
t *testing.T
|
||||||
|
}
|
||||||
|
|
||||||
|
func Discard(t *testing.T) {
|
||||||
|
prev := logx.Reset()
|
||||||
|
logx.SetWriter(logx.NewWriter(io.Discard))
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
logx.SetWriter(prev)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCollector(t *testing.T) *Buffer {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
writer := logx.NewWriter(&buf)
|
||||||
|
prev := logx.Reset()
|
||||||
|
logx.SetWriter(writer)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
logx.SetWriter(prev)
|
||||||
|
})
|
||||||
|
|
||||||
|
return &Buffer{
|
||||||
|
buf: &buf,
|
||||||
|
t: t,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) Bytes() []byte {
|
||||||
|
return b.buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) Content() string {
|
||||||
|
var m map[string]interface{}
|
||||||
|
if err := json.Unmarshal(b.buf.Bytes(), &m); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
content, ok := m["content"]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch val := content.(type) {
|
||||||
|
case string:
|
||||||
|
return val
|
||||||
|
default:
|
||||||
|
// err is impossible to be not nil, unmarshaled from b.buf.Bytes()
|
||||||
|
bs, _ := json.Marshal(content)
|
||||||
|
return string(bs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) Reset() {
|
||||||
|
b.buf.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) String() string {
|
||||||
|
return b.buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func PanicOnFatal(t *testing.T) {
|
||||||
|
ok := logx.ExitOnFatal.CompareAndSwap(true, false)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
logx.ExitOnFatal.CompareAndSwap(false, true)
|
||||||
|
})
|
||||||
|
}
|
||||||
44
core/logx/logtest/logtest_test.go
Normal file
44
core/logx/logtest/logtest_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package logtest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCollector(t *testing.T) {
|
||||||
|
const input = "hello"
|
||||||
|
c := NewCollector(t)
|
||||||
|
logx.Info(input)
|
||||||
|
assert.Equal(t, input, c.Content())
|
||||||
|
assert.Contains(t, c.String(), input)
|
||||||
|
c.Reset()
|
||||||
|
assert.Empty(t, c.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPanicOnFatal(t *testing.T) {
|
||||||
|
const input = "hello"
|
||||||
|
Discard(t)
|
||||||
|
logx.Info(input)
|
||||||
|
|
||||||
|
PanicOnFatal(t)
|
||||||
|
PanicOnFatal(t)
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
logx.Must(errors.New("foo"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollectorContent(t *testing.T) {
|
||||||
|
const input = "hello"
|
||||||
|
c := NewCollector(t)
|
||||||
|
c.buf.WriteString(input)
|
||||||
|
assert.Empty(t, c.Content())
|
||||||
|
c.Reset()
|
||||||
|
c.buf.WriteString(`{}`)
|
||||||
|
assert.Empty(t, c.Content())
|
||||||
|
c.Reset()
|
||||||
|
c.buf.WriteString(`{"content":1}`)
|
||||||
|
assert.Equal(t, "1", c.Content())
|
||||||
|
}
|
||||||
@@ -65,7 +65,7 @@ func (l *richLogger) Errorf(format string, v ...any) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *richLogger) Errorv(v any) {
|
func (l *richLogger) Errorv(v any) {
|
||||||
l.err(fmt.Sprint(v))
|
l.err(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *richLogger) Errorw(msg string, fields ...LogField) {
|
func (l *richLogger) Errorw(msg string, fields ...LogField) {
|
||||||
|
|||||||
@@ -66,6 +66,9 @@ func TestTraceDebug(t *testing.T) {
|
|||||||
l.WithDuration(time.Second).Debugv(testlog)
|
l.WithDuration(time.Second).Debugv(testlog)
|
||||||
validate(t, w.String(), true, true)
|
validate(t, w.String(), true, true)
|
||||||
w.Reset()
|
w.Reset()
|
||||||
|
l.WithDuration(time.Second).Debugv(testobj)
|
||||||
|
validateContentType(t, w.String(), map[string]any{}, true, true)
|
||||||
|
w.Reset()
|
||||||
l.WithDuration(time.Second).Debugw(testlog, Field("foo", "bar"))
|
l.WithDuration(time.Second).Debugw(testlog, Field("foo", "bar"))
|
||||||
validate(t, w.String(), true, true)
|
validate(t, w.String(), true, true)
|
||||||
assert.True(t, strings.Contains(w.String(), "foo"), w.String())
|
assert.True(t, strings.Contains(w.String(), "foo"), w.String())
|
||||||
@@ -103,6 +106,9 @@ func TestTraceError(t *testing.T) {
|
|||||||
l.WithDuration(time.Second).Errorv(testlog)
|
l.WithDuration(time.Second).Errorv(testlog)
|
||||||
validate(t, w.String(), true, true)
|
validate(t, w.String(), true, true)
|
||||||
w.Reset()
|
w.Reset()
|
||||||
|
l.WithDuration(time.Second).Errorv(testobj)
|
||||||
|
validateContentType(t, w.String(), map[string]any{}, true, true)
|
||||||
|
w.Reset()
|
||||||
l.WithDuration(time.Second).Errorw(testlog, Field("basket", "ball"))
|
l.WithDuration(time.Second).Errorw(testlog, Field("basket", "ball"))
|
||||||
validate(t, w.String(), true, true)
|
validate(t, w.String(), true, true)
|
||||||
assert.True(t, strings.Contains(w.String(), "basket"), w.String())
|
assert.True(t, strings.Contains(w.String(), "basket"), w.String())
|
||||||
@@ -137,6 +143,9 @@ func TestTraceInfo(t *testing.T) {
|
|||||||
l.WithDuration(time.Second).Infov(testlog)
|
l.WithDuration(time.Second).Infov(testlog)
|
||||||
validate(t, w.String(), true, true)
|
validate(t, w.String(), true, true)
|
||||||
w.Reset()
|
w.Reset()
|
||||||
|
l.WithDuration(time.Second).Infov(testobj)
|
||||||
|
validateContentType(t, w.String(), map[string]any{}, true, true)
|
||||||
|
w.Reset()
|
||||||
l.WithDuration(time.Second).Infow(testlog, Field("basket", "ball"))
|
l.WithDuration(time.Second).Infow(testlog, Field("basket", "ball"))
|
||||||
validate(t, w.String(), true, true)
|
validate(t, w.String(), true, true)
|
||||||
assert.True(t, strings.Contains(w.String(), "basket"), w.String())
|
assert.True(t, strings.Contains(w.String(), "basket"), w.String())
|
||||||
@@ -173,6 +182,9 @@ func TestTraceInfoConsole(t *testing.T) {
|
|||||||
w.Reset()
|
w.Reset()
|
||||||
l.WithDuration(time.Second).Infov(testlog)
|
l.WithDuration(time.Second).Infov(testlog)
|
||||||
validate(t, w.String(), true, true)
|
validate(t, w.String(), true, true)
|
||||||
|
w.Reset()
|
||||||
|
l.WithDuration(time.Second).Infov(testobj)
|
||||||
|
validateContentType(t, w.String(), map[string]any{}, true, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTraceSlow(t *testing.T) {
|
func TestTraceSlow(t *testing.T) {
|
||||||
@@ -204,6 +216,9 @@ func TestTraceSlow(t *testing.T) {
|
|||||||
l.WithDuration(time.Second).Slowv(testlog)
|
l.WithDuration(time.Second).Slowv(testlog)
|
||||||
validate(t, w.String(), true, true)
|
validate(t, w.String(), true, true)
|
||||||
w.Reset()
|
w.Reset()
|
||||||
|
l.WithDuration(time.Second).Slowv(testobj)
|
||||||
|
validateContentType(t, w.String(), map[string]any{}, true, true)
|
||||||
|
w.Reset()
|
||||||
l.WithDuration(time.Second).Sloww(testlog, Field("basket", "ball"))
|
l.WithDuration(time.Second).Sloww(testlog, Field("basket", "ball"))
|
||||||
validate(t, w.String(), true, true)
|
validate(t, w.String(), true, true)
|
||||||
assert.True(t, strings.Contains(w.String(), "basket"), w.String())
|
assert.True(t, strings.Contains(w.String(), "basket"), w.String())
|
||||||
@@ -311,8 +326,32 @@ func validate(t *testing.T, body string, expectedTrace, expectedSpan bool) {
|
|||||||
assert.Equal(t, expectedSpan, len(val.Span) > 0, body)
|
assert.Equal(t, expectedSpan, len(val.Span) > 0, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockValue struct {
|
func validateContentType(t *testing.T, body string, expectedType any, expectedTrace, expectedSpan bool) {
|
||||||
Trace string `json:"trace"`
|
var val mockValue
|
||||||
Span string `json:"span"`
|
dec := json.NewDecoder(strings.NewReader(body))
|
||||||
Foo string `json:"foo"`
|
|
||||||
|
for {
|
||||||
|
var doc mockValue
|
||||||
|
err := dec.Decode(&doc)
|
||||||
|
if err == io.EOF {
|
||||||
|
// all done
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
val = doc
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.IsType(t, expectedType, val.Content, body)
|
||||||
|
assert.Equal(t, expectedTrace, len(val.Trace) > 0, body)
|
||||||
|
assert.Equal(t, expectedSpan, len(val.Span) > 0, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockValue struct {
|
||||||
|
Trace string `json:"trace"`
|
||||||
|
Span string `json:"span"`
|
||||||
|
Foo string `json:"foo"`
|
||||||
|
Content any `json:"content"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
@@ -299,6 +298,7 @@ func (l *RotateLogger) initialize() error {
|
|||||||
if l.fp, err = os.OpenFile(l.filename, os.O_APPEND|os.O_WRONLY, defaultFileMode); err != nil {
|
if l.fp, err = os.OpenFile(l.filename, os.O_APPEND|os.O_WRONLY, defaultFileMode); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
l.currentSize = fileInfo.Size()
|
l.currentSize = fileInfo.Size()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -382,7 +382,15 @@ func (l *RotateLogger) startWorker() {
|
|||||||
case event := <-l.channel:
|
case event := <-l.channel:
|
||||||
l.write(event)
|
l.write(event)
|
||||||
case <-l.done:
|
case <-l.done:
|
||||||
return
|
// avoid losing logs before closing.
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case event := <-l.channel:
|
||||||
|
l.write(event)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -406,7 +414,7 @@ func (l *RotateLogger) write(v []byte) {
|
|||||||
func compressLogFile(file string) {
|
func compressLogFile(file string) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
Infof("compressing log file: %s", file)
|
Infof("compressing log file: %s", file)
|
||||||
if err := gzipFile(file); err != nil {
|
if err := gzipFile(file, fileSys); err != nil {
|
||||||
Errorf("compress error: %s", err)
|
Errorf("compress error: %s", err)
|
||||||
} else {
|
} else {
|
||||||
Infof("compressed log file: %s, took %s", file, time.Since(start))
|
Infof("compressed log file: %s, took %s", file, time.Since(start))
|
||||||
@@ -421,25 +429,37 @@ func getNowDateInRFC3339Format() string {
|
|||||||
return time.Now().Format(fileTimeFormat)
|
return time.Now().Format(fileTimeFormat)
|
||||||
}
|
}
|
||||||
|
|
||||||
func gzipFile(file string) error {
|
func gzipFile(file string, fsys fileSystem) (err error) {
|
||||||
in, err := os.Open(file)
|
in, err := fsys.Open(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer in.Close()
|
defer func() {
|
||||||
|
if e := fsys.Close(in); e != nil {
|
||||||
|
Errorf("failed to close file: %s, error: %v", file, e)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
// only remove the original file when compression is successful
|
||||||
|
err = fsys.Remove(file)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
out, err := os.Create(fmt.Sprintf("%s%s", file, gzipExt))
|
out, err := fsys.Create(fmt.Sprintf("%s%s", file, gzipExt))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer out.Close()
|
defer func() {
|
||||||
|
e := fsys.Close(out)
|
||||||
|
if err == nil {
|
||||||
|
err = e
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
w := gzip.NewWriter(out)
|
w := gzip.NewWriter(out)
|
||||||
if _, err = io.Copy(w, in); err != nil {
|
if _, err = fsys.Copy(w, in); err != nil {
|
||||||
return err
|
// failed to copy, no need to close w
|
||||||
} else if err = w.Close(); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.Remove(file)
|
return fsys.Close(w)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
package logx
|
package logx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -13,18 +17,58 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestDailyRotateRuleMarkRotated(t *testing.T) {
|
func TestDailyRotateRuleMarkRotated(t *testing.T) {
|
||||||
var rule DailyRotateRule
|
t.Run("daily rule", func(t *testing.T) {
|
||||||
rule.MarkRotated()
|
var rule DailyRotateRule
|
||||||
assert.Equal(t, getNowDate(), rule.rotatedTime)
|
rule.MarkRotated()
|
||||||
|
assert.Equal(t, getNowDate(), rule.rotatedTime)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("daily rule", func(t *testing.T) {
|
||||||
|
rule := DefaultRotateRule("test", "-", 1, false)
|
||||||
|
_, ok := rule.(*DailyRotateRule)
|
||||||
|
assert.True(t, ok)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
||||||
var rule DailyRotateRule
|
t.Run("no files", func(t *testing.T) {
|
||||||
assert.Empty(t, rule.OutdatedFiles())
|
var rule DailyRotateRule
|
||||||
rule.days = 1
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
assert.Empty(t, rule.OutdatedFiles())
|
rule.days = 1
|
||||||
rule.gzip = true
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
assert.Empty(t, rule.OutdatedFiles())
|
rule.gzip = true
|
||||||
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bad files", func(t *testing.T) {
|
||||||
|
rule := DailyRotateRule{
|
||||||
|
filename: "[a-z",
|
||||||
|
}
|
||||||
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
|
rule.days = 1
|
||||||
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
|
rule.gzip = true
|
||||||
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("temp files", func(t *testing.T) {
|
||||||
|
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||||
|
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_ = f1.Close()
|
||||||
|
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_ = f2.Close()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = os.Remove(f1.Name())
|
||||||
|
_ = os.Remove(f2.Name())
|
||||||
|
})
|
||||||
|
rule := DailyRotateRule{
|
||||||
|
filename: path.Join(os.TempDir(), "go-zero-test-"),
|
||||||
|
days: 1,
|
||||||
|
}
|
||||||
|
assert.NotEmpty(t, rule.OutdatedFiles())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDailyRotateRuleShallRotate(t *testing.T) {
|
func TestDailyRotateRuleShallRotate(t *testing.T) {
|
||||||
@@ -34,20 +78,101 @@ func TestDailyRotateRuleShallRotate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSizeLimitRotateRuleMarkRotated(t *testing.T) {
|
func TestSizeLimitRotateRuleMarkRotated(t *testing.T) {
|
||||||
var rule SizeLimitRotateRule
|
t.Run("size limit rule", func(t *testing.T) {
|
||||||
rule.MarkRotated()
|
var rule SizeLimitRotateRule
|
||||||
assert.Equal(t, getNowDateInRFC3339Format(), rule.rotatedTime)
|
rule.MarkRotated()
|
||||||
|
assert.Equal(t, getNowDateInRFC3339Format(), rule.rotatedTime)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("size limit rule", func(t *testing.T) {
|
||||||
|
rule := NewSizeLimitRotateRule("foo", "-", 1, 1, 1, false)
|
||||||
|
rule.MarkRotated()
|
||||||
|
assert.Equal(t, getNowDateInRFC3339Format(), rule.(*SizeLimitRotateRule).rotatedTime)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
||||||
var rule SizeLimitRotateRule
|
t.Run("no files", func(t *testing.T) {
|
||||||
assert.Empty(t, rule.OutdatedFiles())
|
var rule SizeLimitRotateRule
|
||||||
rule.days = 1
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
assert.Empty(t, rule.OutdatedFiles())
|
rule.days = 1
|
||||||
rule.gzip = true
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
assert.Empty(t, rule.OutdatedFiles())
|
rule.gzip = true
|
||||||
rule.maxBackups = 0
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
assert.Empty(t, rule.OutdatedFiles())
|
rule.maxBackups = 0
|
||||||
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bad files", func(t *testing.T) {
|
||||||
|
rule := SizeLimitRotateRule{
|
||||||
|
DailyRotateRule: DailyRotateRule{
|
||||||
|
filename: "[a-z",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
|
rule.days = 1
|
||||||
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
|
rule.gzip = true
|
||||||
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("temp files", func(t *testing.T) {
|
||||||
|
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||||
|
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||||
|
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = f1.Close()
|
||||||
|
_ = os.Remove(f1.Name())
|
||||||
|
_ = f2.Close()
|
||||||
|
_ = os.Remove(f2.Name())
|
||||||
|
_ = f3.Close()
|
||||||
|
_ = os.Remove(f3.Name())
|
||||||
|
})
|
||||||
|
rule := SizeLimitRotateRule{
|
||||||
|
DailyRotateRule: DailyRotateRule{
|
||||||
|
filename: path.Join(os.TempDir(), "go-zero-test-"),
|
||||||
|
days: 1,
|
||||||
|
},
|
||||||
|
maxBackups: 3,
|
||||||
|
}
|
||||||
|
assert.NotEmpty(t, rule.OutdatedFiles())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no backups", func(t *testing.T) {
|
||||||
|
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||||
|
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||||
|
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = f1.Close()
|
||||||
|
_ = os.Remove(f1.Name())
|
||||||
|
_ = f2.Close()
|
||||||
|
_ = os.Remove(f2.Name())
|
||||||
|
_ = f3.Close()
|
||||||
|
_ = os.Remove(f3.Name())
|
||||||
|
})
|
||||||
|
rule := SizeLimitRotateRule{
|
||||||
|
DailyRotateRule: DailyRotateRule{
|
||||||
|
filename: path.Join(os.TempDir(), "go-zero-test-"),
|
||||||
|
days: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.NotEmpty(t, rule.OutdatedFiles())
|
||||||
|
|
||||||
|
logger := new(RotateLogger)
|
||||||
|
logger.rule = &rule
|
||||||
|
logger.maybeDeleteOutdatedFiles()
|
||||||
|
assert.Empty(t, rule.OutdatedFiles())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSizeLimitRotateRuleShallRotate(t *testing.T) {
|
func TestSizeLimitRotateRuleShallRotate(t *testing.T) {
|
||||||
@@ -61,14 +186,47 @@ func TestSizeLimitRotateRuleShallRotate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRotateLoggerClose(t *testing.T) {
|
func TestRotateLoggerClose(t *testing.T) {
|
||||||
filename, err := fs.TempFilenameWithText("foo")
|
t.Run("close", func(t *testing.T) {
|
||||||
assert.Nil(t, err)
|
filename, err := fs.TempFilenameWithText("foo")
|
||||||
if len(filename) > 0 {
|
assert.Nil(t, err)
|
||||||
defer os.Remove(filename)
|
if len(filename) > 0 {
|
||||||
}
|
defer os.Remove(filename)
|
||||||
logger, err := NewLogger(filename, new(DailyRotateRule), false)
|
}
|
||||||
assert.Nil(t, err)
|
logger, err := NewLogger(filename, new(DailyRotateRule), false)
|
||||||
assert.Nil(t, logger.Close())
|
assert.Nil(t, err)
|
||||||
|
_, err = logger.Write([]byte("foo"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Nil(t, logger.Close())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("close and write", func(t *testing.T) {
|
||||||
|
logger := new(RotateLogger)
|
||||||
|
logger.done = make(chan struct{})
|
||||||
|
close(logger.done)
|
||||||
|
_, err := logger.Write([]byte("foo"))
|
||||||
|
assert.ErrorIs(t, err, ErrLogFileClosed)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("close without losing logs", func(t *testing.T) {
|
||||||
|
text := "foo"
|
||||||
|
filename, err := fs.TempFilenameWithText(text)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
if len(filename) > 0 {
|
||||||
|
defer os.Remove(filename)
|
||||||
|
}
|
||||||
|
logger, err := NewLogger(filename, new(DailyRotateRule), false)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
msg := []byte("foo")
|
||||||
|
n := 100
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
_, err = logger.Write(msg)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
|
assert.Nil(t, logger.Close())
|
||||||
|
bs, err := os.ReadFile(filename)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, len(msg)*n+len(text), len(bs))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRotateLoggerGetBackupFilename(t *testing.T) {
|
func TestRotateLoggerGetBackupFilename(t *testing.T) {
|
||||||
@@ -179,7 +337,7 @@ func TestRotateLoggerWithSizeLimitRotateRuleClose(t *testing.T) {
|
|||||||
}
|
}
|
||||||
logger, err := NewLogger(filename, new(SizeLimitRotateRule), false)
|
logger, err := NewLogger(filename, new(SizeLimitRotateRule), false)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Nil(t, logger.Close())
|
_ = logger.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRotateLoggerGetBackupWithSizeLimitRotateRuleFilename(t *testing.T) {
|
func TestRotateLoggerGetBackupWithSizeLimitRotateRuleFilename(t *testing.T) {
|
||||||
@@ -295,6 +453,85 @@ func TestRotateLoggerWithSizeLimitRotateRuleWrite(t *testing.T) {
|
|||||||
logger.write([]byte(`baz`))
|
logger.write([]byte(`baz`))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGzipFile(t *testing.T) {
|
||||||
|
err := errors.New("any error")
|
||||||
|
|
||||||
|
t.Run("gzip file open failed", func(t *testing.T) {
|
||||||
|
fsys := &fakeFileSystem{
|
||||||
|
openFn: func(name string) (*os.File, error) {
|
||||||
|
return nil, err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.ErrorIs(t, err, gzipFile("any", fsys))
|
||||||
|
assert.False(t, fsys.Removed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("gzip file create failed", func(t *testing.T) {
|
||||||
|
fsys := &fakeFileSystem{
|
||||||
|
createFn: func(name string) (*os.File, error) {
|
||||||
|
return nil, err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.ErrorIs(t, err, gzipFile("any", fsys))
|
||||||
|
assert.False(t, fsys.Removed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("gzip file copy failed", func(t *testing.T) {
|
||||||
|
fsys := &fakeFileSystem{
|
||||||
|
copyFn: func(writer io.Writer, reader io.Reader) (int64, error) {
|
||||||
|
return 0, err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.ErrorIs(t, err, gzipFile("any", fsys))
|
||||||
|
assert.False(t, fsys.Removed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("gzip file last close failed", func(t *testing.T) {
|
||||||
|
var called int32
|
||||||
|
fsys := &fakeFileSystem{
|
||||||
|
closeFn: func(closer io.Closer) error {
|
||||||
|
if atomic.AddInt32(&called, 1) > 2 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.NoError(t, gzipFile("any", fsys))
|
||||||
|
assert.True(t, fsys.Removed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("gzip file remove failed", func(t *testing.T) {
|
||||||
|
fsys := &fakeFileSystem{
|
||||||
|
removeFn: func(name string) error {
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.Error(t, err, gzipFile("any", fsys))
|
||||||
|
assert.True(t, fsys.Removed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("gzip file everything ok", func(t *testing.T) {
|
||||||
|
fsys := &fakeFileSystem{}
|
||||||
|
assert.NoError(t, gzipFile("any", fsys))
|
||||||
|
assert.True(t, fsys.Removed())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRotateLogger_WithExistingFile(t *testing.T) {
|
||||||
|
const body = "foo"
|
||||||
|
filename, err := fs.TempFilenameWithText(body)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
if len(filename) > 0 {
|
||||||
|
defer os.Remove(filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := NewSizeLimitRotateRule(filename, "-", 1, 100, 3, false)
|
||||||
|
logger, err := NewLogger(filename, rule, false)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, int64(len(body)), logger.currentSize)
|
||||||
|
assert.Nil(t, logger.Close())
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkRotateLogger(b *testing.B) {
|
func BenchmarkRotateLogger(b *testing.B) {
|
||||||
filename := "./test.log"
|
filename := "./test.log"
|
||||||
filename2 := "./test2.log"
|
filename2 := "./test2.log"
|
||||||
@@ -346,3 +583,53 @@ func BenchmarkRotateLogger(b *testing.B) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type fakeFileSystem struct {
|
||||||
|
removed int32
|
||||||
|
closeFn func(closer io.Closer) error
|
||||||
|
copyFn func(writer io.Writer, reader io.Reader) (int64, error)
|
||||||
|
createFn func(name string) (*os.File, error)
|
||||||
|
openFn func(name string) (*os.File, error)
|
||||||
|
removeFn func(name string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeFileSystem) Close(closer io.Closer) error {
|
||||||
|
if f.closeFn != nil {
|
||||||
|
return f.closeFn(closer)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeFileSystem) Copy(writer io.Writer, reader io.Reader) (int64, error) {
|
||||||
|
if f.copyFn != nil {
|
||||||
|
return f.copyFn(writer, reader)
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeFileSystem) Create(name string) (*os.File, error) {
|
||||||
|
if f.createFn != nil {
|
||||||
|
return f.createFn(name)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeFileSystem) Open(name string) (*os.File, error) {
|
||||||
|
if f.openFn != nil {
|
||||||
|
return f.openFn(name)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeFileSystem) Remove(name string) error {
|
||||||
|
atomic.AddInt32(&f.removed, 1)
|
||||||
|
|
||||||
|
if f.removeFn != nil {
|
||||||
|
return f.removeFn(name)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeFileSystem) Removed() bool {
|
||||||
|
return atomic.LoadInt32(&f.removed) > 0
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
|
|
||||||
const testlog = "Stay hungry, stay foolish."
|
const testlog = "Stay hungry, stay foolish."
|
||||||
|
|
||||||
|
var testobj = map[string]any{"foo": "bar"}
|
||||||
|
|
||||||
func TestCollectSysLog(t *testing.T) {
|
func TestCollectSysLog(t *testing.T) {
|
||||||
CollectSysLog()
|
CollectSysLog()
|
||||||
content := getContent(captureOutput(func() {
|
content := getContent(captureOutput(func() {
|
||||||
|
|||||||
@@ -97,6 +97,15 @@ func TestConsoleWriter(t *testing.T) {
|
|||||||
w.(*concreteWriter).statLog = easyToCloseWriter{}
|
w.(*concreteWriter).statLog = easyToCloseWriter{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewFileWriter(t *testing.T) {
|
||||||
|
t.Run("access", func(t *testing.T) {
|
||||||
|
_, err := newFileWriter(LogConf{
|
||||||
|
Path: "/not-exists",
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestNopWriter(t *testing.T) {
|
func TestNopWriter(t *testing.T) {
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
var w nopWriter
|
var w nopWriter
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -20,6 +21,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
defaultKeyName = "key"
|
defaultKeyName = "key"
|
||||||
delimiter = '.'
|
delimiter = '.'
|
||||||
|
ignoreKey = "-"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -49,6 +51,7 @@ type (
|
|||||||
unmarshalOptions struct {
|
unmarshalOptions struct {
|
||||||
fillDefault bool
|
fillDefault bool
|
||||||
fromString bool
|
fromString bool
|
||||||
|
opaqueKeys bool
|
||||||
canonicalKey func(key string) string
|
canonicalKey func(key string) string
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -72,7 +75,11 @@ func UnmarshalKey(m map[string]any, v any) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Unmarshal unmarshals m into v.
|
// Unmarshal unmarshals m into v.
|
||||||
func (u *Unmarshaler) Unmarshal(i any, v any) error {
|
func (u *Unmarshaler) Unmarshal(i, v any) error {
|
||||||
|
return u.unmarshal(i, v, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Unmarshaler) unmarshal(i, v any, fullName string) error {
|
||||||
valueType := reflect.TypeOf(v)
|
valueType := reflect.TypeOf(v)
|
||||||
if valueType.Kind() != reflect.Ptr {
|
if valueType.Kind() != reflect.Ptr {
|
||||||
return errValueNotSettable
|
return errValueNotSettable
|
||||||
@@ -85,13 +92,13 @@ func (u *Unmarshaler) Unmarshal(i any, v any) error {
|
|||||||
return errTypeMismatch
|
return errTypeMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
return u.UnmarshalValuer(mapValuer(iv), v)
|
return u.unmarshalValuer(mapValuer(iv), v, fullName)
|
||||||
case []any:
|
case []any:
|
||||||
if elemType.Kind() != reflect.Slice {
|
if elemType.Kind() != reflect.Slice {
|
||||||
return errTypeMismatch
|
return errTypeMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
return u.fillSlice(elemType, reflect.ValueOf(v).Elem(), iv)
|
return u.fillSlice(elemType, reflect.ValueOf(v).Elem(), iv, fullName)
|
||||||
default:
|
default:
|
||||||
return errUnsupportedType
|
return errUnsupportedType
|
||||||
}
|
}
|
||||||
@@ -99,17 +106,21 @@ func (u *Unmarshaler) Unmarshal(i any, v any) error {
|
|||||||
|
|
||||||
// UnmarshalValuer unmarshals m into v.
|
// UnmarshalValuer unmarshals m into v.
|
||||||
func (u *Unmarshaler) UnmarshalValuer(m Valuer, v any) error {
|
func (u *Unmarshaler) UnmarshalValuer(m Valuer, v any) error {
|
||||||
return u.unmarshalWithFullName(simpleValuer{current: m}, v, "")
|
return u.unmarshalValuer(simpleValuer{current: m}, v, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Unmarshaler) fillMap(fieldType reflect.Type, value reflect.Value, mapValue any) error {
|
func (u *Unmarshaler) unmarshalValuer(m Valuer, v any, fullName string) error {
|
||||||
|
return u.unmarshalWithFullName(simpleValuer{current: m}, v, fullName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Unmarshaler) fillMap(fieldType reflect.Type, value reflect.Value, mapValue any, fullName string) error {
|
||||||
if !value.CanSet() {
|
if !value.CanSet() {
|
||||||
return errValueNotSettable
|
return errValueNotSettable
|
||||||
}
|
}
|
||||||
|
|
||||||
fieldKeyType := fieldType.Key()
|
fieldKeyType := fieldType.Key()
|
||||||
fieldElemType := fieldType.Elem()
|
fieldElemType := fieldType.Elem()
|
||||||
targetValue, err := u.generateMap(fieldKeyType, fieldElemType, mapValue)
|
targetValue, err := u.generateMap(fieldKeyType, fieldElemType, mapValue, fullName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -143,19 +154,22 @@ func (u *Unmarshaler) fillMapFromString(value reflect.Value, mapValue any) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, mapValue any) error {
|
func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, mapValue any, fullName string) error {
|
||||||
if !value.CanSet() {
|
if !value.CanSet() {
|
||||||
return errValueNotSettable
|
return errValueNotSettable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
refValue := reflect.ValueOf(mapValue)
|
||||||
|
if refValue.Kind() != reflect.Slice {
|
||||||
|
return newTypeMismatchErrorWithHint(fullName, reflect.Slice.String(), refValue.Type().String())
|
||||||
|
}
|
||||||
|
if refValue.IsNil() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
baseType := fieldType.Elem()
|
baseType := fieldType.Elem()
|
||||||
dereffedBaseType := Deref(baseType)
|
dereffedBaseType := Deref(baseType)
|
||||||
dereffedBaseKind := dereffedBaseType.Kind()
|
dereffedBaseKind := dereffedBaseType.Kind()
|
||||||
refValue := reflect.ValueOf(mapValue)
|
|
||||||
if refValue.IsNil() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
|
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
|
||||||
if refValue.Len() == 0 {
|
if refValue.Len() == 0 {
|
||||||
value.Set(conv)
|
value.Set(conv)
|
||||||
@@ -170,20 +184,27 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
|
|||||||
}
|
}
|
||||||
|
|
||||||
valid = true
|
valid = true
|
||||||
|
sliceFullName := fmt.Sprintf("%s[%d]", fullName, i)
|
||||||
|
|
||||||
switch dereffedBaseKind {
|
switch dereffedBaseKind {
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
target := reflect.New(dereffedBaseType)
|
target := reflect.New(dereffedBaseType)
|
||||||
if err := u.Unmarshal(ithValue.(map[string]any), target.Interface()); err != nil {
|
val, ok := ithValue.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return errTypeMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := u.unmarshal(val, target.Interface(), sliceFullName); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
SetValue(fieldType.Elem(), conv.Index(i), target.Elem())
|
SetValue(fieldType.Elem(), conv.Index(i), target.Elem())
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
if err := u.fillSlice(dereffedBaseType, conv.Index(i), ithValue); err != nil {
|
if err := u.fillSlice(dereffedBaseType, conv.Index(i), ithValue, sliceFullName); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if err := u.fillSliceValue(conv, i, dereffedBaseKind, ithValue); err != nil {
|
if err := u.fillSliceValue(conv, i, dereffedBaseKind, ithValue, sliceFullName); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -197,7 +218,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.Value,
|
func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.Value,
|
||||||
mapValue any) error {
|
mapValue any, fullName string) error {
|
||||||
var slice []any
|
var slice []any
|
||||||
switch v := mapValue.(type) {
|
switch v := mapValue.(type) {
|
||||||
case fmt.Stringer:
|
case fmt.Stringer:
|
||||||
@@ -217,7 +238,7 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
|
|||||||
conv := reflect.MakeSlice(reflect.SliceOf(baseFieldType), len(slice), cap(slice))
|
conv := reflect.MakeSlice(reflect.SliceOf(baseFieldType), len(slice), cap(slice))
|
||||||
|
|
||||||
for i := 0; i < len(slice); i++ {
|
for i := 0; i < len(slice); i++ {
|
||||||
if err := u.fillSliceValue(conv, i, baseFieldKind, slice[i]); err != nil {
|
if err := u.fillSliceValue(conv, i, baseFieldKind, slice[i], fullName); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -227,7 +248,7 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
|
func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
|
||||||
baseKind reflect.Kind, value any) error {
|
baseKind reflect.Kind, value any, fullName string) error {
|
||||||
ithVal := slice.Index(index)
|
ithVal := slice.Index(index)
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case fmt.Stringer:
|
case fmt.Stringer:
|
||||||
@@ -235,7 +256,7 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
|
|||||||
case string:
|
case string:
|
||||||
return setValueFromString(baseKind, ithVal, v)
|
return setValueFromString(baseKind, ithVal, v)
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
return u.fillMap(ithVal.Type(), ithVal, value)
|
return u.fillMap(ithVal.Type(), ithVal, value, fullName)
|
||||||
default:
|
default:
|
||||||
// don't need to consider the difference between int, int8, int16, int32, int64,
|
// don't need to consider the difference between int, int8, int16, int32, int64,
|
||||||
// uint, uint8, uint16, uint32, uint64, because they're handled as json.Number.
|
// uint, uint8, uint16, uint32, uint64, because they're handled as json.Number.
|
||||||
@@ -261,7 +282,7 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value reflect.Value,
|
func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value reflect.Value,
|
||||||
defaultValue string) error {
|
defaultValue, fullName string) error {
|
||||||
baseFieldType := Deref(derefedType.Elem())
|
baseFieldType := Deref(derefedType.Elem())
|
||||||
baseFieldKind := baseFieldType.Kind()
|
baseFieldKind := baseFieldType.Kind()
|
||||||
defaultCacheLock.Lock()
|
defaultCacheLock.Lock()
|
||||||
@@ -279,10 +300,10 @@ func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value refle
|
|||||||
defaultCacheLock.Unlock()
|
defaultCacheLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return u.fillSlice(derefedType, value, slice)
|
return u.fillSlice(derefedType, value, slice, fullName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any) (reflect.Value, error) {
|
func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any, fullName string) (reflect.Value, error) {
|
||||||
mapType := reflect.MapOf(keyType, elemType)
|
mapType := reflect.MapOf(keyType, elemType)
|
||||||
valueType := reflect.TypeOf(mapValue)
|
valueType := reflect.TypeOf(mapValue)
|
||||||
if mapType == valueType {
|
if mapType == valueType {
|
||||||
@@ -301,11 +322,12 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any)
|
|||||||
for _, key := range refValue.MapKeys() {
|
for _, key := range refValue.MapKeys() {
|
||||||
keythValue := refValue.MapIndex(key)
|
keythValue := refValue.MapIndex(key)
|
||||||
keythData := keythValue.Interface()
|
keythData := keythValue.Interface()
|
||||||
|
mapFullName := fmt.Sprintf("%s[%s]", fullName, key.String())
|
||||||
|
|
||||||
switch dereffedElemKind {
|
switch dereffedElemKind {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
target := reflect.New(dereffedElemType)
|
target := reflect.New(dereffedElemType)
|
||||||
if err := u.fillSlice(elemType, target.Elem(), keythData); err != nil {
|
if err := u.fillSlice(elemType, target.Elem(), keythData, mapFullName); err != nil {
|
||||||
return emptyValue, err
|
return emptyValue, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -317,7 +339,7 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any)
|
|||||||
}
|
}
|
||||||
|
|
||||||
target := reflect.New(dereffedElemType)
|
target := reflect.New(dereffedElemType)
|
||||||
if err := u.Unmarshal(keythMap, target.Interface()); err != nil {
|
if err := u.unmarshal(keythMap, target.Interface(), mapFullName); err != nil {
|
||||||
return emptyValue, err
|
return emptyValue, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -328,7 +350,7 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any)
|
|||||||
return emptyValue, errTypeMismatch
|
return emptyValue, errTypeMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
innerValue, err := u.generateMap(elemType.Key(), elemType.Elem(), keythMap)
|
innerValue, err := u.generateMap(elemType.Key(), elemType.Elem(), keythMap, mapFullName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return emptyValue, err
|
return emptyValue, err
|
||||||
}
|
}
|
||||||
@@ -347,7 +369,12 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any)
|
|||||||
return emptyValue, errTypeMismatch
|
return emptyValue, errTypeMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
targetValue.SetMapIndex(key, reflect.ValueOf(v))
|
val := reflect.ValueOf(v)
|
||||||
|
if !val.Type().AssignableTo(dereffedElemType) {
|
||||||
|
return emptyValue, errTypeMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
targetValue.SetMapIndex(key, val)
|
||||||
case json.Number:
|
case json.Number:
|
||||||
target := reflect.New(dereffedElemType)
|
target := reflect.New(dereffedElemType)
|
||||||
if err := setValueFromString(dereffedElemKind, target.Elem(), v.String()); err != nil {
|
if err := setValueFromString(dereffedElemKind, target.Elem(), v.String()); err != nil {
|
||||||
@@ -412,6 +439,10 @@ func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value ref
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if key == ignoreKey {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if options.optional() {
|
if options.optional() {
|
||||||
return u.processAnonymousFieldOptional(field, value, key, m, fullName)
|
return u.processAnonymousFieldOptional(field, value, key, m, fullName)
|
||||||
}
|
}
|
||||||
@@ -470,7 +501,7 @@ func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, hasValue := getValue(m, fieldKey)
|
_, hasValue := getValue(m, fieldKey, u.opts.opaqueKeys)
|
||||||
if hasValue {
|
if hasValue {
|
||||||
if !filled {
|
if !filled {
|
||||||
filled = true
|
filled = true
|
||||||
@@ -513,8 +544,8 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
|
|||||||
vp valueWithParent, opts *fieldOptionsWithContext, fullName string) error {
|
vp valueWithParent, opts *fieldOptionsWithContext, fullName string) error {
|
||||||
derefedFieldType := Deref(fieldType)
|
derefedFieldType := Deref(fieldType)
|
||||||
typeKind := derefedFieldType.Kind()
|
typeKind := derefedFieldType.Kind()
|
||||||
valueKind := reflect.TypeOf(vp.value).Kind()
|
|
||||||
mapValue := vp.value
|
mapValue := vp.value
|
||||||
|
valueKind := reflect.TypeOf(mapValue).Kind()
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case valueKind == reflect.Map && typeKind == reflect.Struct:
|
case valueKind == reflect.Map && typeKind == reflect.Struct:
|
||||||
@@ -527,12 +558,14 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
|
|||||||
current: mapValuer(mv),
|
current: mapValuer(mv),
|
||||||
parent: vp.parent,
|
parent: vp.parent,
|
||||||
}, fullName)
|
}, fullName)
|
||||||
|
case typeKind == reflect.Slice && valueKind == reflect.Slice:
|
||||||
|
return u.fillSlice(fieldType, value, mapValue, fullName)
|
||||||
case valueKind == reflect.Map && typeKind == reflect.Map:
|
case valueKind == reflect.Map && typeKind == reflect.Map:
|
||||||
return u.fillMap(fieldType, value, mapValue)
|
return u.fillMap(fieldType, value, mapValue, fullName)
|
||||||
case valueKind == reflect.String && typeKind == reflect.Map:
|
case valueKind == reflect.String && typeKind == reflect.Map:
|
||||||
return u.fillMapFromString(value, mapValue)
|
return u.fillMapFromString(value, mapValue)
|
||||||
case valueKind == reflect.String && typeKind == reflect.Slice:
|
case valueKind == reflect.String && typeKind == reflect.Slice:
|
||||||
return u.fillSliceFromString(fieldType, value, mapValue)
|
return u.fillSliceFromString(fieldType, value, mapValue, fullName)
|
||||||
case valueKind == reflect.String && derefedFieldType == durationType:
|
case valueKind == reflect.String && derefedFieldType == durationType:
|
||||||
return fillDurationValue(fieldType, value, mapValue.(string))
|
return fillDurationValue(fieldType, value, mapValue.(string))
|
||||||
default:
|
default:
|
||||||
@@ -545,23 +578,16 @@ func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflec
|
|||||||
typeKind := Deref(fieldType).Kind()
|
typeKind := Deref(fieldType).Kind()
|
||||||
valueKind := reflect.TypeOf(mapValue).Kind()
|
valueKind := reflect.TypeOf(mapValue).Kind()
|
||||||
|
|
||||||
switch {
|
switch v := mapValue.(type) {
|
||||||
case typeKind == reflect.Slice && valueKind == reflect.Slice:
|
case json.Number:
|
||||||
return u.fillSlice(fieldType, value, mapValue)
|
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, v, opts, fullName)
|
||||||
case typeKind == reflect.Map && valueKind == reflect.Map:
|
|
||||||
return u.fillMap(fieldType, value, mapValue)
|
|
||||||
default:
|
default:
|
||||||
switch v := mapValue.(type) {
|
if typeKind == valueKind {
|
||||||
case json.Number:
|
if err := validateValueInOptions(mapValue, opts.options()); err != nil {
|
||||||
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, v, opts, fullName)
|
return err
|
||||||
default:
|
|
||||||
if typeKind == valueKind {
|
|
||||||
if err := validateValueInOptions(mapValue, opts.options()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return fillWithSameType(fieldType, value, mapValue, opts)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return fillWithSameType(fieldType, value, mapValue, opts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -584,25 +610,23 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
|
|||||||
target := reflect.New(Deref(fieldType)).Elem()
|
target := reflect.New(Deref(fieldType)).Elem()
|
||||||
|
|
||||||
switch typeKind {
|
switch typeKind {
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||||
iValue, err := v.Int64()
|
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
if err := setValueFromString(typeKind, target, v.String()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case reflect.Float32:
|
||||||
|
fValue, err := v.Float64()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
target.SetInt(iValue)
|
if fValue > math.MaxFloat32 {
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
return float32OverflowError(v.String())
|
||||||
iValue, err := v.Int64()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if iValue < 0 {
|
target.SetFloat(fValue)
|
||||||
return fmt.Errorf("unmarshal %q with bad value %q", fullName, v.String())
|
case reflect.Float64:
|
||||||
}
|
|
||||||
|
|
||||||
target.SetUint(uint64(iValue))
|
|
||||||
case reflect.Float32, reflect.Float64:
|
|
||||||
fValue, err := v.Float64()
|
fValue, err := v.Float64()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -610,7 +634,7 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
|
|||||||
|
|
||||||
target.SetFloat(fValue)
|
target.SetFloat(fValue)
|
||||||
default:
|
default:
|
||||||
return newTypeMismatchError(fullName)
|
return newTypeMismatchErrorWithHint(fullName, typeKind.String(), value.Type().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
SetValue(fieldType, value, target)
|
SetValue(fieldType, value, target)
|
||||||
@@ -704,6 +728,10 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if key == ignoreKey {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
fullName = join(fullName, key)
|
fullName = join(fullName, key)
|
||||||
if opts != nil && len(opts.EnvVar) > 0 {
|
if opts != nil && len(opts.EnvVar) > 0 {
|
||||||
envVal := proc.Env(opts.EnvVar)
|
envVal := proc.Env(opts.EnvVar)
|
||||||
@@ -718,7 +746,7 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
|
|||||||
}
|
}
|
||||||
|
|
||||||
valuer := createValuer(m, opts)
|
valuer := createValuer(m, opts)
|
||||||
mapValue, hasValue := getValue(valuer, canonicalKey)
|
mapValue, hasValue := getValue(valuer, canonicalKey, u.opts.opaqueKeys)
|
||||||
|
|
||||||
// When fillDefault is used, m is a null value, hasValue must be false, all priority judgments fillDefault.
|
// When fillDefault is used, m is a null value, hasValue must be false, all priority judgments fillDefault.
|
||||||
if u.opts.fillDefault {
|
if u.opts.fillDefault {
|
||||||
@@ -811,7 +839,7 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu
|
|||||||
|
|
||||||
switch fieldKind {
|
switch fieldKind {
|
||||||
case reflect.Array, reflect.Slice:
|
case reflect.Array, reflect.Slice:
|
||||||
return u.fillSliceWithDefault(derefedType, value, defaultValue)
|
return u.fillSliceWithDefault(derefedType, value, defaultValue, fullName)
|
||||||
default:
|
default:
|
||||||
return setValueFromString(fieldKind, value, defaultValue)
|
return setValueFromString(fieldKind, value, defaultValue)
|
||||||
}
|
}
|
||||||
@@ -859,7 +887,7 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu
|
|||||||
|
|
||||||
func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v any, fullName string) error {
|
func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v any, fullName string) error {
|
||||||
rv := reflect.ValueOf(v)
|
rv := reflect.ValueOf(v)
|
||||||
if err := ValidatePtr(&rv); err != nil {
|
if err := ValidatePtr(rv); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -881,11 +909,6 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v any, fullName
|
|||||||
typeField := baseType.Field(i)
|
typeField := baseType.Field(i)
|
||||||
valueField := valElem.Field(i)
|
valueField := valElem.Field(i)
|
||||||
if err := u.processField(typeField, valueField, m, fullName); err != nil {
|
if err := u.processField(typeField, valueField, m, fullName); err != nil {
|
||||||
if len(fullName) > 0 {
|
|
||||||
err = fmt.Errorf("%w, fullName: %s, field: %s, type: %s",
|
|
||||||
err, fullName, typeField.Name, valueField.Type().Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -914,6 +937,14 @@ func WithDefault() UnmarshalOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithOpaqueKeys customizes an Unmarshaler with opaque keys.
|
||||||
|
// Opaque keys are keys that are not processed by the unmarshaler.
|
||||||
|
func WithOpaqueKeys() UnmarshalOption {
|
||||||
|
return func(opt *unmarshalOptions) {
|
||||||
|
opt.opaqueKeys = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func createValuer(v valuerWithParent, opts *fieldOptionsWithContext) valuerWithParent {
|
func createValuer(v valuerWithParent, opts *fieldOptionsWithContext) valuerWithParent {
|
||||||
if opts.inherit() {
|
if opts.inherit() {
|
||||||
return recursiveValuer{
|
return recursiveValuer{
|
||||||
@@ -991,8 +1022,8 @@ func fillWithSameType(fieldType reflect.Type, value reflect.Value, mapValue any,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getValue gets the value for the specific key, the key can be in the format of parentKey.childKey
|
// getValue gets the value for the specific key, the key can be in the format of parentKey.childKey
|
||||||
func getValue(m valuerWithParent, key string) (any, bool) {
|
func getValue(m valuerWithParent, key string, opaque bool) (any, bool) {
|
||||||
keys := readKeys(key)
|
keys := readKeys(key, opaque)
|
||||||
return getValueWithChainedKeys(m, keys)
|
return getValueWithChainedKeys(m, keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1046,7 +1077,16 @@ func newTypeMismatchError(name string) error {
|
|||||||
return fmt.Errorf("type mismatch for field %q", name)
|
return fmt.Errorf("type mismatch for field %q", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func readKeys(key string) []string {
|
func newTypeMismatchErrorWithHint(name, expectType, actualType string) error {
|
||||||
|
return fmt.Errorf("type mismatch for field %q, expect %q, actual %q",
|
||||||
|
name, expectType, actualType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readKeys(key string, opaque bool) []string {
|
||||||
|
if opaque {
|
||||||
|
return []string{key}
|
||||||
|
}
|
||||||
|
|
||||||
cacheKeysLock.Lock()
|
cacheKeysLock.Lock()
|
||||||
keys, ok := cacheKeys[key]
|
keys, ok := cacheKeys[key]
|
||||||
cacheKeysLock.Unlock()
|
cacheKeysLock.Unlock()
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -42,6 +42,10 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
|
integer interface {
|
||||||
|
~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
|
||||||
|
}
|
||||||
|
|
||||||
optionsCacheValue struct {
|
optionsCacheValue struct {
|
||||||
key string
|
key string
|
||||||
options *fieldOptions
|
options *fieldOptions
|
||||||
@@ -79,7 +83,7 @@ func SetMapIndexValue(tp reflect.Type, value, key, target reflect.Value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidatePtr validates v if it's a valid pointer.
|
// ValidatePtr validates v if it's a valid pointer.
|
||||||
func ValidatePtr(v *reflect.Value) error {
|
func ValidatePtr(v reflect.Value) error {
|
||||||
// sequence is very important, IsNil must be called after checking Kind() with reflect.Ptr,
|
// sequence is very important, IsNil must be called after checking Kind() with reflect.Ptr,
|
||||||
// panic otherwise
|
// panic otherwise
|
||||||
if !v.IsValid() || v.Kind() != reflect.Ptr || v.IsNil() {
|
if !v.IsValid() || v.Kind() != reflect.Ptr || v.IsNil() {
|
||||||
@@ -103,21 +107,32 @@ func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
|
|||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
intValue, err := strconv.ParseInt(str, 10, 64)
|
intValue, err := strconv.ParseInt(str, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("the value %q cannot parsed as int", str)
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return intValue, nil
|
return intValue, nil
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
uintValue, err := strconv.ParseUint(str, 10, 64)
|
uintValue, err := strconv.ParseUint(str, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("the value %q cannot parsed as uint", str)
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return uintValue, nil
|
return uintValue, nil
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32:
|
||||||
floatValue, err := strconv.ParseFloat(str, 64)
|
floatValue, err := strconv.ParseFloat(str, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("the value %q cannot parsed as float", str)
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if floatValue > math.MaxFloat32 {
|
||||||
|
return 0, float32OverflowError(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
return floatValue, nil
|
||||||
|
case reflect.Float64:
|
||||||
|
floatValue, err := strconv.ParseFloat(str, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return floatValue, nil
|
return floatValue, nil
|
||||||
@@ -215,6 +230,10 @@ func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func intOverflowError[T integer](v T, kind reflect.Kind) error {
|
||||||
|
return fmt.Errorf("parsing \"%d\" as %s: value out of range", v, kind.String())
|
||||||
|
}
|
||||||
|
|
||||||
func isLeftInclude(b byte) (bool, error) {
|
func isLeftInclude(b byte) (bool, error) {
|
||||||
switch b {
|
switch b {
|
||||||
case '[':
|
case '[':
|
||||||
@@ -237,6 +256,10 @@ func isRightInclude(b byte) (bool, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func float32OverflowError(str string) error {
|
||||||
|
return fmt.Errorf("parsing %q as float32: value out of range", str)
|
||||||
|
}
|
||||||
|
|
||||||
func maybeNewValue(fieldType reflect.Type, value reflect.Value) {
|
func maybeNewValue(fieldType reflect.Type, value reflect.Value) {
|
||||||
if fieldType.Kind() == reflect.Ptr && value.IsNil() {
|
if fieldType.Kind() == reflect.Ptr && value.IsNil() {
|
||||||
value.Set(reflect.New(value.Type().Elem()))
|
value.Set(reflect.New(value.Type().Elem()))
|
||||||
@@ -372,8 +395,6 @@ func parseOption(fieldOpts *fieldOptions, fieldName, option string) error {
|
|||||||
default:
|
default:
|
||||||
return fmt.Errorf("field %q has wrong optional", fieldName)
|
return fmt.Errorf("field %q has wrong optional", fieldName)
|
||||||
}
|
}
|
||||||
case option == optionalOption:
|
|
||||||
fieldOpts.Optional = true
|
|
||||||
case strings.HasPrefix(option, optionsOption):
|
case strings.HasPrefix(option, optionsOption):
|
||||||
val, err := parseProperty(fieldName, optionsOption, option)
|
val, err := parseProperty(fieldName, optionsOption, option)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -484,22 +505,61 @@ func parseSegments(val string) []string {
|
|||||||
return segments
|
return segments
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setIntValue(value reflect.Value, v any, min, max int64) error {
|
||||||
|
iv := v.(int64)
|
||||||
|
if iv < min || iv > max {
|
||||||
|
return intOverflowError(iv, value.Kind())
|
||||||
|
}
|
||||||
|
|
||||||
|
value.SetInt(iv)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v any) error {
|
func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v any) error {
|
||||||
switch kind {
|
switch kind {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
value.SetBool(v.(bool))
|
value.SetBool(v.(bool))
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
return nil
|
||||||
|
case reflect.Int: // int depends on int size, 32 or 64
|
||||||
|
return setIntValue(value, v, math.MinInt, math.MaxInt)
|
||||||
|
case reflect.Int8:
|
||||||
|
return setIntValue(value, v, math.MinInt8, math.MaxInt8)
|
||||||
|
case reflect.Int16:
|
||||||
|
return setIntValue(value, v, math.MinInt16, math.MaxInt16)
|
||||||
|
case reflect.Int32:
|
||||||
|
return setIntValue(value, v, math.MinInt32, math.MaxInt32)
|
||||||
|
case reflect.Int64:
|
||||||
value.SetInt(v.(int64))
|
value.SetInt(v.(int64))
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
return nil
|
||||||
|
case reflect.Uint: // uint depends on int size, 32 or 64
|
||||||
|
return setUintValue(value, v, math.MaxUint)
|
||||||
|
case reflect.Uint8:
|
||||||
|
return setUintValue(value, v, math.MaxUint8)
|
||||||
|
case reflect.Uint16:
|
||||||
|
return setUintValue(value, v, math.MaxUint16)
|
||||||
|
case reflect.Uint32:
|
||||||
|
return setUintValue(value, v, math.MaxUint32)
|
||||||
|
case reflect.Uint64:
|
||||||
value.SetUint(v.(uint64))
|
value.SetUint(v.(uint64))
|
||||||
|
return nil
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
value.SetFloat(v.(float64))
|
value.SetFloat(v.(float64))
|
||||||
|
return nil
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
value.SetString(v.(string))
|
value.SetString(v.(string))
|
||||||
|
return nil
|
||||||
default:
|
default:
|
||||||
return errUnsupportedType
|
return errUnsupportedType
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setUintValue(value reflect.Value, v any, boundary uint64) error {
|
||||||
|
iv := v.(uint64)
|
||||||
|
if iv > boundary {
|
||||||
|
return intOverflowError(iv, value.Kind())
|
||||||
|
}
|
||||||
|
|
||||||
|
value.SetUint(iv)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -577,7 +637,8 @@ func usingDifferentKeys(key string, field reflect.StructField) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string, opts *fieldOptionsWithContext) error {
|
func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string,
|
||||||
|
opts *fieldOptionsWithContext) error {
|
||||||
if !value.CanSet() {
|
if !value.CanSet() {
|
||||||
return errValueNotSettable
|
return errValueNotSettable
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -218,30 +218,31 @@ func TestParseSegments(t *testing.T) {
|
|||||||
func TestValidatePtrWithNonPtr(t *testing.T) {
|
func TestValidatePtrWithNonPtr(t *testing.T) {
|
||||||
var foo string
|
var foo string
|
||||||
rve := reflect.ValueOf(foo)
|
rve := reflect.ValueOf(foo)
|
||||||
assert.NotNil(t, ValidatePtr(&rve))
|
assert.NotNil(t, ValidatePtr(rve))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidatePtrWithPtr(t *testing.T) {
|
func TestValidatePtrWithPtr(t *testing.T) {
|
||||||
var foo string
|
var foo string
|
||||||
rve := reflect.ValueOf(&foo)
|
rve := reflect.ValueOf(&foo)
|
||||||
assert.Nil(t, ValidatePtr(&rve))
|
assert.Nil(t, ValidatePtr(rve))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidatePtrWithNilPtr(t *testing.T) {
|
func TestValidatePtrWithNilPtr(t *testing.T) {
|
||||||
var foo *string
|
var foo *string
|
||||||
rve := reflect.ValueOf(foo)
|
rve := reflect.ValueOf(foo)
|
||||||
assert.NotNil(t, ValidatePtr(&rve))
|
assert.NotNil(t, ValidatePtr(rve))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidatePtrWithZeroValue(t *testing.T) {
|
func TestValidatePtrWithZeroValue(t *testing.T) {
|
||||||
var s string
|
var s string
|
||||||
e := reflect.Zero(reflect.TypeOf(s))
|
e := reflect.Zero(reflect.TypeOf(s))
|
||||||
assert.NotNil(t, ValidatePtr(&e))
|
assert.NotNil(t, ValidatePtr(e))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetValueNotSettable(t *testing.T) {
|
func TestSetValueNotSettable(t *testing.T) {
|
||||||
var i int
|
var i int
|
||||||
assert.NotNil(t, setValueFromString(reflect.Int, reflect.ValueOf(i), "1"))
|
assert.Error(t, setValueFromString(reflect.Int, reflect.ValueOf(i), "1"))
|
||||||
|
assert.Error(t, validateAndSetValue(reflect.Int, reflect.ValueOf(i), "1", nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseKeyAndOptionsErrors(t *testing.T) {
|
func TestParseKeyAndOptionsErrors(t *testing.T) {
|
||||||
@@ -300,3 +301,36 @@ func TestSetValueFormatErrors(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateValueRange(t *testing.T) {
|
||||||
|
t.Run("float", func(t *testing.T) {
|
||||||
|
assert.NoError(t, validateValueRange(1.2, nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("float number range", func(t *testing.T) {
|
||||||
|
assert.NoError(t, validateNumberRange(1.2, nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bad float", func(t *testing.T) {
|
||||||
|
assert.Error(t, validateValueRange("a", &fieldOptionsWithContext{
|
||||||
|
Range: &numberRange{},
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bad float validate", func(t *testing.T) {
|
||||||
|
var v struct {
|
||||||
|
Foo float32
|
||||||
|
}
|
||||||
|
assert.Error(t, validateAndSetValue(reflect.Int, reflect.ValueOf(&v).Elem().Field(0),
|
||||||
|
"1", &fieldOptionsWithContext{
|
||||||
|
Range: &numberRange{
|
||||||
|
left: 2,
|
||||||
|
right: 3,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetMatchedPrimitiveValue(t *testing.T) {
|
||||||
|
assert.Error(t, setMatchedPrimitiveValue(reflect.Func, reflect.ValueOf(2), "1"))
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/proc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewHistogramVec(t *testing.T) {
|
func TestNewHistogramVec(t *testing.T) {
|
||||||
@@ -48,6 +47,4 @@ func TestHistogramObserve(t *testing.T) {
|
|||||||
|
|
||||||
err := testutil.CollectAndCompare(hv.histogram, strings.NewReader(metadata+val))
|
err := testutil.CollectAndCompare(hv.histogram, strings.NewReader(metadata+val))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
proc.Shutdown()
|
|
||||||
}
|
}
|
||||||
|
|||||||
65
core/metric/summary.go
Normal file
65
core/metric/summary.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package metric
|
||||||
|
|
||||||
|
import (
|
||||||
|
prom "github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
|
"github.com/zeromicro/go-zero/core/prometheus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// A SummaryVecOpts is a summary vector options
|
||||||
|
SummaryVecOpts struct {
|
||||||
|
VecOpt VectorOpts
|
||||||
|
Objectives map[float64]float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// A SummaryVec interface represents a summary vector.
|
||||||
|
SummaryVec interface {
|
||||||
|
// Observe adds observation v to labels.
|
||||||
|
Observe(v float64, labels ...string)
|
||||||
|
close() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
promSummaryVec struct {
|
||||||
|
summary *prom.SummaryVec
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewSummaryVec return a SummaryVec
|
||||||
|
func NewSummaryVec(cfg *SummaryVecOpts) SummaryVec {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
vec := prom.NewSummaryVec(
|
||||||
|
prom.SummaryOpts{
|
||||||
|
Namespace: cfg.VecOpt.Namespace,
|
||||||
|
Subsystem: cfg.VecOpt.Subsystem,
|
||||||
|
Name: cfg.VecOpt.Name,
|
||||||
|
Help: cfg.VecOpt.Help,
|
||||||
|
Objectives: cfg.Objectives,
|
||||||
|
},
|
||||||
|
cfg.VecOpt.Labels,
|
||||||
|
)
|
||||||
|
prom.MustRegister(vec)
|
||||||
|
sv := &promSummaryVec{
|
||||||
|
summary: vec,
|
||||||
|
}
|
||||||
|
proc.AddShutdownListener(func() {
|
||||||
|
sv.close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return sv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sv *promSummaryVec) Observe(v float64, labels ...string) {
|
||||||
|
if !prometheus.Enabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sv.summary.WithLabelValues(labels...).Observe(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sv *promSummaryVec) close() bool {
|
||||||
|
return prom.Unregister(sv.summary)
|
||||||
|
}
|
||||||
68
core/metric/summary_test.go
Normal file
68
core/metric/summary_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package metric
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSummaryVec(t *testing.T) {
|
||||||
|
summaryVec := NewSummaryVec(&SummaryVecOpts{
|
||||||
|
VecOpt: VectorOpts{
|
||||||
|
Namespace: "http_server",
|
||||||
|
Subsystem: "requests",
|
||||||
|
Name: "duration_quantiles",
|
||||||
|
Help: "rpc client requests duration(ms) φ quantiles ",
|
||||||
|
Labels: []string{"method"},
|
||||||
|
},
|
||||||
|
Objectives: map[float64]float64{
|
||||||
|
0.5: 0.01,
|
||||||
|
0.9: 0.01,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer summaryVec.close()
|
||||||
|
summaryVecNil := NewSummaryVec(nil)
|
||||||
|
assert.NotNil(t, summaryVec)
|
||||||
|
assert.Nil(t, summaryVecNil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSummaryObserve(t *testing.T) {
|
||||||
|
startAgent()
|
||||||
|
summaryVec := NewSummaryVec(&SummaryVecOpts{
|
||||||
|
VecOpt: VectorOpts{
|
||||||
|
Namespace: "http_server",
|
||||||
|
Subsystem: "requests",
|
||||||
|
Name: "duration_quantiles",
|
||||||
|
Help: "rpc client requests duration(ms) φ quantiles ",
|
||||||
|
Labels: []string{"method"},
|
||||||
|
},
|
||||||
|
Objectives: map[float64]float64{
|
||||||
|
0.3: 0.01,
|
||||||
|
0.6: 0.01,
|
||||||
|
1: 0.01,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer summaryVec.close()
|
||||||
|
sv := summaryVec.(*promSummaryVec)
|
||||||
|
sv.Observe(100, "GET")
|
||||||
|
sv.Observe(200, "GET")
|
||||||
|
sv.Observe(300, "GET")
|
||||||
|
metadata := `
|
||||||
|
# HELP http_server_requests_duration_quantiles rpc client requests duration(ms) φ quantiles
|
||||||
|
# TYPE http_server_requests_duration_quantiles summary
|
||||||
|
`
|
||||||
|
val := `
|
||||||
|
http_server_requests_duration_quantiles{method="GET",quantile="0.3"} 100
|
||||||
|
http_server_requests_duration_quantiles{method="GET",quantile="0.6"} 200
|
||||||
|
http_server_requests_duration_quantiles{method="GET",quantile="1"} 300
|
||||||
|
http_server_requests_duration_quantiles_sum{method="GET"} 600
|
||||||
|
http_server_requests_duration_quantiles_count{method="GET"} 3
|
||||||
|
`
|
||||||
|
|
||||||
|
err := testutil.CollectAndCompare(sv.summary, strings.NewReader(metadata+val))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
proc.Shutdown()
|
||||||
|
}
|
||||||
@@ -3,7 +3,7 @@ package mr
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io/ioutil"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -17,7 +17,7 @@ import (
|
|||||||
var errDummy = errors.New("dummy")
|
var errDummy = errors.New("dummy")
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
log.SetOutput(ioutil.Discard)
|
log.SetOutput(io.Discard)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFinish(t *testing.T) {
|
func TestFinish(t *testing.T) {
|
||||||
@@ -574,6 +574,7 @@ func TestMapReduceWithContext(t *testing.T) {
|
|||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
writer.Write(i)
|
writer.Write(i)
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
}, func(pipe <-chan int, cancel func(error)) {
|
}, func(pipe <-chan int, cancel func(error)) {
|
||||||
for item := range pipe {
|
for item := range pipe {
|
||||||
i := item
|
i := item
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package proc
|
package proc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -21,13 +20,11 @@ func TestEnvInt(t *testing.T) {
|
|||||||
val, ok := EnvInt("any")
|
val, ok := EnvInt("any")
|
||||||
assert.Equal(t, 0, val)
|
assert.Equal(t, 0, val)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
err := os.Setenv("anyInt", "10")
|
t.Setenv("anyInt", "10")
|
||||||
assert.Nil(t, err)
|
|
||||||
val, ok = EnvInt("anyInt")
|
val, ok = EnvInt("anyInt")
|
||||||
assert.Equal(t, 10, val)
|
assert.Equal(t, 10, val)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
err = os.Setenv("anyString", "a")
|
t.Setenv("anyString", "a")
|
||||||
assert.Nil(t, err)
|
|
||||||
val, ok = EnvInt("anyString")
|
val, ok = EnvInt("anyString")
|
||||||
assert.Equal(t, 0, val)
|
assert.Equal(t, 0, val)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
|
|||||||
@@ -1,6 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package proc
|
|
||||||
|
|
||||||
func dumpGoroutines() {
|
|
||||||
}
|
|
||||||
@@ -18,7 +18,11 @@ const (
|
|||||||
debugLevel = 2
|
debugLevel = 2
|
||||||
)
|
)
|
||||||
|
|
||||||
func dumpGoroutines() {
|
type creator interface {
|
||||||
|
Create(name string) (file *os.File, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dumpGoroutines(ctor creator) {
|
||||||
command := path.Base(os.Args[0])
|
command := path.Base(os.Args[0])
|
||||||
pid := syscall.Getpid()
|
pid := syscall.Getpid()
|
||||||
dumpFile := path.Join(os.TempDir(), fmt.Sprintf("%s-%d-goroutines-%s.dump",
|
dumpFile := path.Join(os.TempDir(), fmt.Sprintf("%s-%d-goroutines-%s.dump",
|
||||||
@@ -26,10 +30,16 @@ func dumpGoroutines() {
|
|||||||
|
|
||||||
logx.Infof("Got dump goroutine signal, printing goroutine profile to %s", dumpFile)
|
logx.Infof("Got dump goroutine signal, printing goroutine profile to %s", dumpFile)
|
||||||
|
|
||||||
if f, err := os.Create(dumpFile); err != nil {
|
if f, err := ctor.Create(dumpFile); err != nil {
|
||||||
logx.Errorf("Failed to dump goroutine profile, error: %v", err)
|
logx.Errorf("Failed to dump goroutine profile, error: %v", err)
|
||||||
} else {
|
} else {
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
pprof.Lookup(goroutineProfile).WriteTo(f, debugLevel)
|
pprof.Lookup(goroutineProfile).WriteTo(f, debugLevel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type fileCreator struct{}
|
||||||
|
|
||||||
|
func (fc fileCreator) Create(name string) (file *os.File, err error) {
|
||||||
|
return os.Create(name)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,23 +1,41 @@
|
|||||||
|
//go:build linux || darwin
|
||||||
|
|
||||||
package proc
|
package proc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDumpGoroutines(t *testing.T) {
|
func TestDumpGoroutines(t *testing.T) {
|
||||||
var buf strings.Builder
|
t.Run("real file", func(t *testing.T) {
|
||||||
w := logx.NewWriter(&buf)
|
buf := logtest.NewCollector(t)
|
||||||
o := logx.Reset()
|
dumpGoroutines(fileCreator{})
|
||||||
logx.SetWriter(w)
|
assert.True(t, strings.Contains(buf.String(), ".dump"))
|
||||||
defer func() {
|
})
|
||||||
logx.Reset()
|
|
||||||
logx.SetWriter(o)
|
|
||||||
}()
|
|
||||||
|
|
||||||
dumpGoroutines()
|
t.Run("fake file", func(t *testing.T) {
|
||||||
assert.True(t, strings.Contains(buf.String(), ".dump"))
|
const msg = "any message"
|
||||||
|
buf := logtest.NewCollector(t)
|
||||||
|
err := errors.New(msg)
|
||||||
|
dumpGoroutines(fakeCreator{
|
||||||
|
file: &os.File{},
|
||||||
|
err: err,
|
||||||
|
})
|
||||||
|
assert.True(t, strings.Contains(buf.String(), msg))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeCreator struct {
|
||||||
|
file *os.File
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc fakeCreator) Create(name string) (file *os.File, err error) {
|
||||||
|
return fc.file, fc.err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,25 +5,16 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProfile(t *testing.T) {
|
func TestProfile(t *testing.T) {
|
||||||
var buf strings.Builder
|
c := logtest.NewCollector(t)
|
||||||
w := logx.NewWriter(&buf)
|
|
||||||
o := logx.Reset()
|
|
||||||
logx.SetWriter(w)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
logx.Reset()
|
|
||||||
logx.SetWriter(o)
|
|
||||||
}()
|
|
||||||
|
|
||||||
profiler := StartProfile()
|
profiler := StartProfile()
|
||||||
// start again should not work
|
// start again should not work
|
||||||
assert.NotNil(t, StartProfile())
|
assert.NotNil(t, StartProfile())
|
||||||
profiler.Stop()
|
profiler.Stop()
|
||||||
// stop twice
|
// stop twice
|
||||||
profiler.Stop()
|
profiler.Stop()
|
||||||
assert.True(t, strings.Contains(buf.String(), ".pprof"))
|
assert.True(t, strings.Contains(c.String(), ".pprof"))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -96,4 +96,6 @@ func (lm *listenerManager) notifyListeners() {
|
|||||||
group.RunSafe(listener)
|
group.RunSafe(listener)
|
||||||
}
|
}
|
||||||
group.Wait()
|
group.Wait()
|
||||||
|
|
||||||
|
lm.listeners = nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,3 +28,33 @@ func TestShutdown(t *testing.T) {
|
|||||||
called()
|
called()
|
||||||
assert.Equal(t, 3, val)
|
assert.Equal(t, 3, val)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNotifyMoreThanOnce(t *testing.T) {
|
||||||
|
ch := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
var val int
|
||||||
|
called := AddWrapUpListener(func() {
|
||||||
|
val++
|
||||||
|
})
|
||||||
|
WrapUp()
|
||||||
|
WrapUp()
|
||||||
|
called()
|
||||||
|
assert.Equal(t, 1, val)
|
||||||
|
|
||||||
|
called = AddShutdownListener(func() {
|
||||||
|
val += 2
|
||||||
|
})
|
||||||
|
Shutdown()
|
||||||
|
Shutdown()
|
||||||
|
called()
|
||||||
|
assert.Equal(t, 3, val)
|
||||||
|
ch <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timeout, check error logs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func init() {
|
|||||||
v := <-signals
|
v := <-signals
|
||||||
switch v {
|
switch v {
|
||||||
case syscall.SIGUSR1:
|
case syscall.SIGUSR1:
|
||||||
dumpGoroutines()
|
dumpGoroutines(fileCreator{})
|
||||||
case syscall.SIGUSR2:
|
case syscall.SIGUSR2:
|
||||||
if profiler == nil {
|
if profiler == nil {
|
||||||
profiler = StartProfile()
|
profiler = StartProfile()
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build linux || darwin
|
||||||
|
|
||||||
package proc
|
package proc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package prof
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -13,6 +15,10 @@ const (
|
|||||||
|
|
||||||
// DisplayStats prints the goroutine, memory, GC stats with given interval, default to 5 seconds.
|
// DisplayStats prints the goroutine, memory, GC stats with given interval, default to 5 seconds.
|
||||||
func DisplayStats(interval ...time.Duration) {
|
func DisplayStats(interval ...time.Duration) {
|
||||||
|
displayStatsWithWriter(os.Stdout, interval...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func displayStatsWithWriter(writer io.Writer, interval ...time.Duration) {
|
||||||
duration := defaultInterval
|
duration := defaultInterval
|
||||||
for _, val := range interval {
|
for _, val := range interval {
|
||||||
duration = val
|
duration = val
|
||||||
@@ -24,7 +30,7 @@ func DisplayStats(interval ...time.Duration) {
|
|||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
var m runtime.MemStats
|
var m runtime.MemStats
|
||||||
runtime.ReadMemStats(&m)
|
runtime.ReadMemStats(&m)
|
||||||
fmt.Printf("Goroutines: %d, Alloc: %vm, TotalAlloc: %vm, Sys: %vm, NumGC: %v\n",
|
fmt.Fprintf(writer, "Goroutines: %d, Alloc: %vm, TotalAlloc: %vm, Sys: %vm, NumGC: %v\n",
|
||||||
runtime.NumGoroutine(), m.Alloc/mega, m.TotalAlloc/mega, m.Sys/mega, m.NumGC)
|
runtime.NumGoroutine(), m.Alloc/mega, m.TotalAlloc/mega, m.Sys/mega, m.NumGC)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
36
core/prof/runtime_test.go
Normal file
36
core/prof/runtime_test.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package prof
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDisplayStats(t *testing.T) {
|
||||||
|
writer := &threadSafeBuffer{
|
||||||
|
buf: strings.Builder{},
|
||||||
|
}
|
||||||
|
displayStatsWithWriter(writer, time.Millisecond*10)
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
assert.Contains(t, writer.String(), "Goroutines: ")
|
||||||
|
}
|
||||||
|
|
||||||
|
type threadSafeBuffer struct {
|
||||||
|
buf strings.Builder
|
||||||
|
lock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *threadSafeBuffer) String() string {
|
||||||
|
b.lock.Lock()
|
||||||
|
defer b.lock.Unlock()
|
||||||
|
return b.buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *threadSafeBuffer) Write(p []byte) (n int, err error) {
|
||||||
|
b.lock.Lock()
|
||||||
|
defer b.lock.Unlock()
|
||||||
|
return b.buf.Write(p)
|
||||||
|
}
|
||||||
@@ -21,6 +21,11 @@ func Enabled() bool {
|
|||||||
return enabled.True()
|
return enabled.True()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enable enables prometheus.
|
||||||
|
func Enable() {
|
||||||
|
enabled.Set(true)
|
||||||
|
}
|
||||||
|
|
||||||
// StartAgent starts a prometheus agent.
|
// StartAgent starts a prometheus agent.
|
||||||
func StartAgent(c Config) {
|
func StartAgent(c Config) {
|
||||||
if len(c.Host) == 0 {
|
if len(c.Host) == 0 {
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package queue
|
package queue
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -37,10 +39,82 @@ func TestQueue(t *testing.T) {
|
|||||||
assert.Equal(t, int32(rounds), atomic.LoadInt32(&consumer.count))
|
assert.Equal(t, int32(rounds), atomic.LoadInt32(&consumer.count))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueue_Broadcast(t *testing.T) {
|
||||||
|
producer := newMockedProducer(math.MaxInt32)
|
||||||
|
consumer := newMockedConsumer()
|
||||||
|
consumer.wait.Add(consumers)
|
||||||
|
q := NewQueue(func() (Producer, error) {
|
||||||
|
return producer, nil
|
||||||
|
}, func() (Consumer, error) {
|
||||||
|
return consumer, nil
|
||||||
|
})
|
||||||
|
q.AddListener(new(mockedListener))
|
||||||
|
q.SetName("mockqueue")
|
||||||
|
q.SetNumConsumer(consumers)
|
||||||
|
q.SetNumProducer(1)
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
q.Stop()
|
||||||
|
}()
|
||||||
|
go q.Start()
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
q.Broadcast("message")
|
||||||
|
consumer.wait.Wait()
|
||||||
|
assert.Equal(t, int32(consumers), atomic.LoadInt32(&consumer.events))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueue_PauseResume(t *testing.T) {
|
||||||
|
producer := newMockedProducer(rounds)
|
||||||
|
consumer := newMockedConsumer()
|
||||||
|
consumer.wait.Add(consumers)
|
||||||
|
q := NewQueue(func() (Producer, error) {
|
||||||
|
return producer, nil
|
||||||
|
}, func() (Consumer, error) {
|
||||||
|
return consumer, nil
|
||||||
|
})
|
||||||
|
q.AddListener(new(mockedListener))
|
||||||
|
q.SetName("mockqueue")
|
||||||
|
q.SetNumConsumer(consumers)
|
||||||
|
q.SetNumProducer(1)
|
||||||
|
go func() {
|
||||||
|
producer.wait.Wait()
|
||||||
|
q.Stop()
|
||||||
|
}()
|
||||||
|
q.Start()
|
||||||
|
producer.listener.OnProducerPause()
|
||||||
|
assert.Equal(t, int32(0), atomic.LoadInt32(&q.active))
|
||||||
|
producer.listener.OnProducerResume()
|
||||||
|
assert.Equal(t, int32(1), atomic.LoadInt32(&q.active))
|
||||||
|
assert.Equal(t, int32(rounds), atomic.LoadInt32(&consumer.count))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueue_ConsumeError(t *testing.T) {
|
||||||
|
producer := newMockedProducer(rounds)
|
||||||
|
consumer := newMockedConsumer()
|
||||||
|
consumer.consumeErr = errors.New("consume error")
|
||||||
|
consumer.wait.Add(consumers)
|
||||||
|
q := NewQueue(func() (Producer, error) {
|
||||||
|
return producer, nil
|
||||||
|
}, func() (Consumer, error) {
|
||||||
|
return consumer, nil
|
||||||
|
})
|
||||||
|
q.AddListener(new(mockedListener))
|
||||||
|
q.SetName("mockqueue")
|
||||||
|
q.SetNumConsumer(consumers)
|
||||||
|
q.SetNumProducer(1)
|
||||||
|
go func() {
|
||||||
|
producer.wait.Wait()
|
||||||
|
q.Stop()
|
||||||
|
}()
|
||||||
|
q.Start()
|
||||||
|
assert.Equal(t, int32(rounds), atomic.LoadInt32(&consumer.count))
|
||||||
|
}
|
||||||
|
|
||||||
type mockedConsumer struct {
|
type mockedConsumer struct {
|
||||||
count int32
|
count int32
|
||||||
events int32
|
events int32
|
||||||
wait sync.WaitGroup
|
consumeErr error
|
||||||
|
wait sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMockedConsumer() *mockedConsumer {
|
func newMockedConsumer() *mockedConsumer {
|
||||||
@@ -49,7 +123,7 @@ func newMockedConsumer() *mockedConsumer {
|
|||||||
|
|
||||||
func (c *mockedConsumer) Consume(string) error {
|
func (c *mockedConsumer) Consume(string) error {
|
||||||
atomic.AddInt32(&c.count, 1)
|
atomic.AddInt32(&c.count, 1)
|
||||||
return nil
|
return c.consumeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *mockedConsumer) OnEvent(any) {
|
func (c *mockedConsumer) OnEvent(any) {
|
||||||
@@ -59,9 +133,10 @@ func (c *mockedConsumer) OnEvent(any) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type mockedProducer struct {
|
type mockedProducer struct {
|
||||||
total int32
|
total int32
|
||||||
count int32
|
count int32
|
||||||
wait sync.WaitGroup
|
listener ProduceListener
|
||||||
|
wait sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMockedProducer(total int32) *mockedProducer {
|
func newMockedProducer(total int32) *mockedProducer {
|
||||||
@@ -72,6 +147,7 @@ func newMockedProducer(total int32) *mockedProducer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *mockedProducer) AddListener(listener ProduceListener) {
|
func (p *mockedProducer) AddListener(listener ProduceListener) {
|
||||||
|
p.listener = listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *mockedProducer) Produce() (string, bool) {
|
func (p *mockedProducer) Produce() (string, bool) {
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
package rescue
|
package rescue
|
||||||
|
|
||||||
import "github.com/zeromicro/go-zero/core/logx"
|
import (
|
||||||
|
"context"
|
||||||
|
"runtime/debug"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
)
|
||||||
|
|
||||||
// Recover is used with defer to do cleanup on panics.
|
// Recover is used with defer to do cleanup on panics.
|
||||||
// Use it like:
|
// Use it like:
|
||||||
@@ -15,3 +20,14 @@ func Recover(cleanups ...func()) {
|
|||||||
logx.ErrorStack(p)
|
logx.ErrorStack(p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecoverCtx is used with defer to do cleanup on panics.
|
||||||
|
func RecoverCtx(ctx context.Context, cleanups ...func()) {
|
||||||
|
for _, cleanup := range cleanups {
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
if p := recover(); p != nil {
|
||||||
|
logx.WithContext(ctx).Errorf("%+v\n%s", p, debug.Stack())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package rescue
|
package rescue
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -25,3 +26,17 @@ func TestRescue(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.Equal(t, int32(5), atomic.LoadInt32(&count))
|
assert.Equal(t, int32(5), atomic.LoadInt32(&count))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRescueCtx(t *testing.T) {
|
||||||
|
var count int32
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
defer RecoverCtx(context.Background(), func() {
|
||||||
|
atomic.AddInt32(&count, 2)
|
||||||
|
}, func() {
|
||||||
|
atomic.AddInt32(&count, 3)
|
||||||
|
})
|
||||||
|
|
||||||
|
panic("hello")
|
||||||
|
})
|
||||||
|
assert.Equal(t, int32(5), atomic.LoadInt32(&count))
|
||||||
|
}
|
||||||
|
|||||||
@@ -171,11 +171,11 @@ func add(nd *node, route string, item any) error {
|
|||||||
token := route[:i]
|
token := route[:i]
|
||||||
children := nd.getChildren(token)
|
children := nd.getChildren(token)
|
||||||
if child, ok := children[token]; ok {
|
if child, ok := children[token]; ok {
|
||||||
if child != nil {
|
if child == nil {
|
||||||
return add(child, route[i+1:], item)
|
return errInvalidState
|
||||||
}
|
}
|
||||||
|
|
||||||
return errInvalidState
|
return add(child, route[i+1:], item)
|
||||||
}
|
}
|
||||||
|
|
||||||
child := newNode(nil)
|
child := newNode(nil)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
type mockedRoute struct {
|
type mockedRoute struct {
|
||||||
route string
|
route string
|
||||||
value int
|
value any
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSearch(t *testing.T) {
|
func TestSearch(t *testing.T) {
|
||||||
@@ -187,6 +187,12 @@ func TestSearchInvalidItem(t *testing.T) {
|
|||||||
assert.Equal(t, errEmptyItem, err)
|
assert.Equal(t, errEmptyItem, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSearchInvalidState(t *testing.T) {
|
||||||
|
nd := newNode("0")
|
||||||
|
nd.children[0]["1"] = nil
|
||||||
|
assert.Error(t, add(nd, "1/2", "2"))
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkSearchTree(b *testing.B) {
|
func BenchmarkSearchTree(b *testing.B) {
|
||||||
const (
|
const (
|
||||||
avgLen = 1000
|
avgLen = 1000
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/load"
|
"github.com/zeromicro/go-zero/core/load"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/proc"
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
@@ -39,9 +37,7 @@ type ServiceConf struct {
|
|||||||
|
|
||||||
// MustSetUp sets up the service, exits on error.
|
// MustSetUp sets up the service, exits on error.
|
||||||
func (sc ServiceConf) MustSetUp() {
|
func (sc ServiceConf) MustSetUp() {
|
||||||
if err := sc.SetUp(); err != nil {
|
logx.Must(sc.SetUp())
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUp sets up the service.
|
// SetUp sets up the service.
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ func (sg *ServiceGroup) doStart() {
|
|||||||
|
|
||||||
for i := range sg.services {
|
for i := range sg.services {
|
||||||
service := sg.services[i]
|
service := sg.services[i]
|
||||||
routineGroup.RunSafe(func() {
|
routineGroup.Run(func() {
|
||||||
service.Start()
|
service.Start()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,30 +14,6 @@ var (
|
|||||||
done = make(chan struct{})
|
done = make(chan struct{})
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockedService struct {
|
|
||||||
quit chan struct{}
|
|
||||||
multiplier int
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMockedService(multiplier int) *mockedService {
|
|
||||||
return &mockedService{
|
|
||||||
quit: make(chan struct{}),
|
|
||||||
multiplier: multiplier,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mockedService) Start() {
|
|
||||||
mutex.Lock()
|
|
||||||
number *= s.multiplier
|
|
||||||
mutex.Unlock()
|
|
||||||
done <- struct{}{}
|
|
||||||
<-s.quit
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mockedService) Stop() {
|
|
||||||
close(s.quit)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServiceGroup(t *testing.T) {
|
func TestServiceGroup(t *testing.T) {
|
||||||
multipliers := []int{2, 3, 5, 7}
|
multipliers := []int{2, 3, 5, 7}
|
||||||
want := 1
|
want := 1
|
||||||
@@ -126,3 +102,27 @@ type mockedStarter struct {
|
|||||||
func (s mockedStarter) Start() {
|
func (s mockedStarter) Start() {
|
||||||
s.fn()
|
s.fn()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockedService struct {
|
||||||
|
quit chan struct{}
|
||||||
|
multiplier int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockedService(multiplier int) *mockedService {
|
||||||
|
return &mockedService{
|
||||||
|
quit: make(chan struct{}),
|
||||||
|
multiplier: multiplier,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mockedService) Start() {
|
||||||
|
mutex.Lock()
|
||||||
|
number *= s.multiplier
|
||||||
|
mutex.Unlock()
|
||||||
|
done <- struct{}{}
|
||||||
|
<-s.quit
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mockedService) Stop() {
|
||||||
|
close(s.quit)
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
package stat
|
package stat
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -12,8 +11,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestReport(t *testing.T) {
|
func TestReport(t *testing.T) {
|
||||||
os.Setenv(clusterNameKey, "test-cluster")
|
t.Setenv(clusterNameKey, "test-cluster")
|
||||||
defer os.Unsetenv(clusterNameKey)
|
|
||||||
|
|
||||||
var count int32
|
var count int32
|
||||||
SetReporter(func(s string) {
|
SetReporter(func(s string) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package internal
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -218,6 +219,7 @@ func parseUints(val string) ([]uint64, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var sets []uint64
|
||||||
ints := make(map[uint64]lang.PlaceholderType)
|
ints := make(map[uint64]lang.PlaceholderType)
|
||||||
cols := strings.Split(val, ",")
|
cols := strings.Split(val, ",")
|
||||||
for _, r := range cols {
|
for _, r := range cols {
|
||||||
@@ -238,7 +240,10 @@ func parseUints(val string) ([]uint64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := min; i <= max; i++ {
|
for i := min; i <= max; i++ {
|
||||||
ints[i] = lang.Placeholder
|
if _, ok := ints[i]; !ok {
|
||||||
|
ints[i] = lang.Placeholder
|
||||||
|
sets = append(sets, i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
v, err := parseUint(r)
|
v, err := parseUint(r)
|
||||||
@@ -246,19 +251,17 @@ func parseUints(val string) ([]uint64, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ints[v] = lang.Placeholder
|
if _, ok := ints[v]; !ok {
|
||||||
|
ints[v] = lang.Placeholder
|
||||||
|
sets = append(sets, v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var sets []uint64
|
|
||||||
for k := range ints {
|
|
||||||
sets = append(sets, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
return sets, nil
|
return sets, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// runningInUserNS detects whether we are currently running in an user namespace.
|
// runningInUserNS detects whether we are currently running in a user namespace.
|
||||||
func runningInUserNS() bool {
|
func runningInUserNS() bool {
|
||||||
nsOnce.Do(func() {
|
nsOnce.Do(func() {
|
||||||
file, err := os.Open("/proc/self/uid_map")
|
file, err := os.Open("/proc/self/uid_map")
|
||||||
@@ -280,9 +283,10 @@ func runningInUserNS() bool {
|
|||||||
|
|
||||||
// We assume we are in the initial user namespace if we have a full
|
// We assume we are in the initial user namespace if we have a full
|
||||||
// range - 4294967295 uids starting at uid 0.
|
// range - 4294967295 uids starting at uid 0.
|
||||||
if a == 0 && b == 0 && c == 4294967295 {
|
if a == 0 && b == 0 && c == math.MaxUint32 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
inUserNS = true
|
inUserNS = true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
71
core/stat/internal/cgroup_linux_test.go
Normal file
71
core/stat/internal/cgroup_linux_test.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRunningInUserNS(t *testing.T) {
|
||||||
|
// should be false in docker
|
||||||
|
assert.False(t, runningInUserNS())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCgroupV1(t *testing.T) {
|
||||||
|
if isCgroup2UnifiedMode() {
|
||||||
|
cg, err := currentCgroupV1()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = cg.cpus()
|
||||||
|
assert.Error(t, err)
|
||||||
|
_, err = cg.cpuPeriodUs()
|
||||||
|
assert.Error(t, err)
|
||||||
|
_, err = cg.cpuQuotaUs()
|
||||||
|
assert.Error(t, err)
|
||||||
|
_, err = cg.usageAllCpus()
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseUint(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want uint64
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{"0", 0, nil},
|
||||||
|
{"123", 123, nil},
|
||||||
|
{"-1", 0, nil},
|
||||||
|
{"-18446744073709551616", 0, nil},
|
||||||
|
{"foo", 0, fmt.Errorf("cgroup: bad int format: foo")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got, err := parseUint(tt.input)
|
||||||
|
assert.Equal(t, tt.err, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseUints(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want []uint64
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{"", nil, nil},
|
||||||
|
{"1,2,3", []uint64{1, 2, 3}, nil},
|
||||||
|
{"1-3", []uint64{1, 2, 3}, nil},
|
||||||
|
{"1-3,5,7-9", []uint64{1, 2, 3, 5, 7, 8, 9}, nil},
|
||||||
|
{"foo", nil, fmt.Errorf("cgroup: bad int format: foo")},
|
||||||
|
{"1-bar", nil, fmt.Errorf("cgroup: bad int list format: 1-bar")},
|
||||||
|
{"bar-3", nil, fmt.Errorf("cgroup: bad int list format: bar-3")},
|
||||||
|
{"3-1", nil, fmt.Errorf("cgroup: bad int list format: 3-1")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got, err := parseUints(tt.input)
|
||||||
|
assert.Equal(t, tt.err, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -141,7 +141,7 @@ func (c *metricsContainer) Execute(v any) {
|
|||||||
report.Median = float32(medianTask.Duration) / float32(time.Millisecond)
|
report.Median = float32(medianTask.Duration) / float32(time.Millisecond)
|
||||||
tenPercent := fiftyPercent / 5
|
tenPercent := fiftyPercent / 5
|
||||||
if tenPercent > 0 {
|
if tenPercent > 0 {
|
||||||
top10pTasks := topK(tasks, tenPercent)
|
top10pTasks := topK(top50pTasks, tenPercent)
|
||||||
task90th := top10pTasks[0]
|
task90th := top10pTasks[0]
|
||||||
report.Top90th = float32(task90th.Duration) / float32(time.Millisecond)
|
report.Top90th = float32(task90th.Duration) / float32(time.Millisecond)
|
||||||
onePercent := tenPercent / 10
|
onePercent := tenPercent / 10
|
||||||
@@ -163,7 +163,7 @@ func (c *metricsContainer) Execute(v any) {
|
|||||||
report.Top99p9th = mostDuration
|
report.Top99p9th = mostDuration
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
mostDuration := getTopDuration(tasks)
|
mostDuration := getTopDuration(top50pTasks)
|
||||||
report.Top90th = mostDuration
|
report.Top90th = mostDuration
|
||||||
report.Top99th = mostDuration
|
report.Top99th = mostDuration
|
||||||
report.Top99p9th = mostDuration
|
report.Top99p9th = mostDuration
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
package stat
|
package stat
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBToMb(t *testing.T) {
|
func TestBToMb(t *testing.T) {
|
||||||
@@ -41,15 +40,11 @@ func TestBToMb(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPrintUsage(t *testing.T) {
|
func TestPrintUsage(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
c := logtest.NewCollector(t)
|
||||||
writer := logx.NewWriter(&buf)
|
|
||||||
old := logx.Reset()
|
|
||||||
logx.SetWriter(writer)
|
|
||||||
defer logx.SetWriter(old)
|
|
||||||
|
|
||||||
printUsage()
|
printUsage()
|
||||||
|
|
||||||
output := buf.String()
|
output := c.String()
|
||||||
assert.Contains(t, output, "CPU:")
|
assert.Contains(t, output, "CPU:")
|
||||||
assert.Contains(t, output, "MEMORY:")
|
assert.Contains(t, output, "MEMORY:")
|
||||||
assert.Contains(t, output, "Alloc=")
|
assert.Contains(t, output, "Alloc=")
|
||||||
|
|||||||
@@ -69,3 +69,62 @@ func TestFieldNamesWithDashTagAndOptions(t *testing.T) {
|
|||||||
assert.Equal(t, expected, out)
|
assert.Equal(t, expected, out)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPostgreSqlJoin(t *testing.T) {
|
||||||
|
// Test with empty input array
|
||||||
|
var input []string
|
||||||
|
var expectedOutput string
|
||||||
|
assert.Equal(t, expectedOutput, PostgreSqlJoin(input))
|
||||||
|
|
||||||
|
// Test with single element input array
|
||||||
|
input = []string{"foo"}
|
||||||
|
expectedOutput = "foo = $2"
|
||||||
|
assert.Equal(t, expectedOutput, PostgreSqlJoin(input))
|
||||||
|
|
||||||
|
// Test with multiple elements input array
|
||||||
|
input = []string{"foo", "bar", "baz"}
|
||||||
|
expectedOutput = "foo = $2, bar = $3, baz = $4"
|
||||||
|
assert.Equal(t, expectedOutput, PostgreSqlJoin(input))
|
||||||
|
}
|
||||||
|
|
||||||
|
type testStruct struct {
|
||||||
|
Foo string `db:"foo"`
|
||||||
|
Bar int `db:"bar"`
|
||||||
|
Baz bool `db:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRawFieldNames(t *testing.T) {
|
||||||
|
// Test with a struct without tags
|
||||||
|
in := struct {
|
||||||
|
Foo string
|
||||||
|
Bar int
|
||||||
|
}{}
|
||||||
|
expectedOutput := []string{"`Foo`", "`Bar`"}
|
||||||
|
assert.ElementsMatch(t, expectedOutput, RawFieldNames(in))
|
||||||
|
|
||||||
|
// Test pg without db tag
|
||||||
|
expectedOutput = []string{"Foo", "Bar"}
|
||||||
|
assert.ElementsMatch(t, expectedOutput, RawFieldNames(in, true))
|
||||||
|
|
||||||
|
// Test with a struct with tags
|
||||||
|
input := testStruct{}
|
||||||
|
expectedOutput = []string{"`foo`", "`bar`"}
|
||||||
|
assert.ElementsMatch(t, expectedOutput, RawFieldNames(input))
|
||||||
|
|
||||||
|
// Test with nil input (pointer)
|
||||||
|
var nilInput *testStruct
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
RawFieldNames(nilInput)
|
||||||
|
}, "RawFieldNames should panic with nil input")
|
||||||
|
|
||||||
|
// Test with non-struct input
|
||||||
|
inputInt := 42
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
RawFieldNames(inputInt)
|
||||||
|
}, "RawFieldNames should panic with non-struct input")
|
||||||
|
|
||||||
|
// Test with PostgreSQL flag
|
||||||
|
input = testStruct{}
|
||||||
|
expectedOutput = []string{"foo", "bar"}
|
||||||
|
assert.ElementsMatch(t, expectedOutput, RawFieldNames(input, true))
|
||||||
|
}
|
||||||
|
|||||||
136
core/stores/cache/cachenode_test.go
vendored
136
core/stores/cache/cachenode_test.go
vendored
@@ -1,6 +1,3 @@
|
|||||||
//go:build !race
|
|
||||||
|
|
||||||
// Disable data race detection is because of the timingWheel in cacheNode.
|
|
||||||
package cache
|
package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -34,8 +31,10 @@ func init() {
|
|||||||
|
|
||||||
func TestCacheNode_DelCache(t *testing.T) {
|
func TestCacheNode_DelCache(t *testing.T) {
|
||||||
t.Run("del cache", func(t *testing.T) {
|
t.Run("del cache", func(t *testing.T) {
|
||||||
store := redistest.CreateRedis(t)
|
r, err := miniredis.Run()
|
||||||
store.Type = redis.ClusterType
|
assert.NoError(t, err)
|
||||||
|
defer r.Close()
|
||||||
|
store := redis.New(r.Addr(), redis.Cluster())
|
||||||
|
|
||||||
cn := cacheNode{
|
cn := cacheNode{
|
||||||
rds: store,
|
rds: store,
|
||||||
@@ -56,16 +55,16 @@ func TestCacheNode_DelCache(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("del cache with errors", func(t *testing.T) {
|
t.Run("del cache with errors", func(t *testing.T) {
|
||||||
old := timingWheel
|
old := timingWheel.Load()
|
||||||
ticker := timex.NewFakeTicker()
|
ticker := timex.NewFakeTicker()
|
||||||
var err error
|
tw, err := collection.NewTimingWheelWithTicker(
|
||||||
timingWheel, err = collection.NewTimingWheelWithTicker(
|
|
||||||
time.Millisecond, timingWheelSlots, func(key, value any) {
|
time.Millisecond, timingWheelSlots, func(key, value any) {
|
||||||
clean(key, value)
|
clean(key, value)
|
||||||
}, ticker)
|
}, ticker)
|
||||||
|
timingWheel.Store(tw)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
timingWheel = old
|
timingWheel.Store(old)
|
||||||
})
|
})
|
||||||
|
|
||||||
r, err := miniredis.Run()
|
r, err := miniredis.Run()
|
||||||
@@ -166,40 +165,99 @@ func TestCacheNode_TakeBadRedis(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCacheNode_TakeNotFound(t *testing.T) {
|
func TestCacheNode_TakeNotFound(t *testing.T) {
|
||||||
store := redistest.CreateRedis(t)
|
t.Run("not found", func(t *testing.T) {
|
||||||
|
store := redistest.CreateRedis(t)
|
||||||
|
|
||||||
cn := cacheNode{
|
cn := cacheNode{
|
||||||
rds: store,
|
rds: store,
|
||||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||||
barrier: syncx.NewSingleFlight(),
|
barrier: syncx.NewSingleFlight(),
|
||||||
lock: new(sync.Mutex),
|
lock: new(sync.Mutex),
|
||||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||||
stat: NewStat("any"),
|
stat: NewStat("any"),
|
||||||
errNotFound: errTestNotFound,
|
errNotFound: errTestNotFound,
|
||||||
}
|
}
|
||||||
var str string
|
var str string
|
||||||
err := cn.Take(&str, "any", func(v any) error {
|
err := cn.Take(&str, "any", func(v any) error {
|
||||||
return errTestNotFound
|
return errTestNotFound
|
||||||
})
|
})
|
||||||
assert.True(t, cn.IsNotFound(err))
|
assert.True(t, cn.IsNotFound(err))
|
||||||
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
||||||
val, err := store.Get("any")
|
val, err := store.Get("any")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, `*`, val)
|
assert.Equal(t, `*`, val)
|
||||||
|
|
||||||
store.Set("any", "*")
|
store.Set("any", "*")
|
||||||
err = cn.Take(&str, "any", func(v any) error {
|
err = cn.Take(&str, "any", func(v any) error {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
assert.True(t, cn.IsNotFound(err))
|
assert.True(t, cn.IsNotFound(err))
|
||||||
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
||||||
|
|
||||||
store.Del("any")
|
store.Del("any")
|
||||||
errDummy := errors.New("dummy")
|
errDummy := errors.New("dummy")
|
||||||
err = cn.Take(&str, "any", func(v any) error {
|
err = cn.Take(&str, "any", func(v any) error {
|
||||||
return errDummy
|
return errDummy
|
||||||
|
})
|
||||||
|
assert.Equal(t, errDummy, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not found with redis error", func(t *testing.T) {
|
||||||
|
r, err := miniredis.Run()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer r.Close()
|
||||||
|
store, err := redis.NewRedis(redis.RedisConf{
|
||||||
|
Host: r.Addr(),
|
||||||
|
Type: redis.NodeType,
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
cn := cacheNode{
|
||||||
|
rds: store,
|
||||||
|
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||||
|
barrier: syncx.NewSingleFlight(),
|
||||||
|
lock: new(sync.Mutex),
|
||||||
|
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||||
|
stat: NewStat("any"),
|
||||||
|
errNotFound: errTestNotFound,
|
||||||
|
}
|
||||||
|
var str string
|
||||||
|
err = cn.Take(&str, "any", func(v any) error {
|
||||||
|
r.SetError("mock error")
|
||||||
|
return errTestNotFound
|
||||||
|
})
|
||||||
|
assert.True(t, cn.IsNotFound(err))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheNode_TakeCtxWithRedisError(t *testing.T) {
|
||||||
|
t.Run("not found with redis error", func(t *testing.T) {
|
||||||
|
r, err := miniredis.Run()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer r.Close()
|
||||||
|
store, err := redis.NewRedis(redis.RedisConf{
|
||||||
|
Host: r.Addr(),
|
||||||
|
Type: redis.NodeType,
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
cn := cacheNode{
|
||||||
|
rds: store,
|
||||||
|
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||||
|
barrier: syncx.NewSingleFlight(),
|
||||||
|
lock: new(sync.Mutex),
|
||||||
|
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||||
|
stat: NewStat("any"),
|
||||||
|
errNotFound: errTestNotFound,
|
||||||
|
}
|
||||||
|
var str string
|
||||||
|
err = cn.Take(&str, "any", func(v any) error {
|
||||||
|
str = "foo"
|
||||||
|
r.SetError("mock error")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
assert.Equal(t, errDummy, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
|
func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
|
||||||
|
|||||||
28
core/stores/cache/cacheopt_test.go
vendored
Normal file
28
core/stores/cache/cacheopt_test.go
vendored
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCacheOptions(t *testing.T) {
|
||||||
|
t.Run("default options", func(t *testing.T) {
|
||||||
|
o := newOptions()
|
||||||
|
assert.Equal(t, defaultExpiry, o.Expiry)
|
||||||
|
assert.Equal(t, defaultNotFoundExpiry, o.NotFoundExpiry)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with expiry", func(t *testing.T) {
|
||||||
|
o := newOptions(WithExpiry(time.Second))
|
||||||
|
assert.Equal(t, time.Second, o.Expiry)
|
||||||
|
assert.Equal(t, defaultNotFoundExpiry, o.NotFoundExpiry)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with not found expiry", func(t *testing.T) {
|
||||||
|
o := newOptions(WithNotFoundExpiry(time.Second))
|
||||||
|
assert.Equal(t, defaultExpiry, o.Expiry)
|
||||||
|
assert.Equal(t, time.Second, o.NotFoundExpiry)
|
||||||
|
})
|
||||||
|
}
|
||||||
24
core/stores/cache/cleaner.go
vendored
24
core/stores/cache/cleaner.go
vendored
@@ -2,6 +2,7 @@ package cache
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/collection"
|
"github.com/zeromicro/go-zero/core/collection"
|
||||||
@@ -19,7 +20,8 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
timingWheel *collection.TimingWheel
|
// use atomic to avoid data race in unit tests
|
||||||
|
timingWheel atomic.Value
|
||||||
taskRunner = threading.NewTaskRunner(cleanWorkers)
|
taskRunner = threading.NewTaskRunner(cleanWorkers)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -30,22 +32,27 @@ type delayTask struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
var err error
|
tw, err := collection.NewTimingWheel(time.Second, timingWheelSlots, clean)
|
||||||
timingWheel, err = collection.NewTimingWheel(time.Second, timingWheelSlots, clean)
|
|
||||||
logx.Must(err)
|
logx.Must(err)
|
||||||
|
timingWheel.Store(tw)
|
||||||
|
|
||||||
proc.AddShutdownListener(func() {
|
proc.AddShutdownListener(func() {
|
||||||
timingWheel.Drain(clean)
|
if err := tw.Drain(clean); err != nil {
|
||||||
|
logx.Errorf("failed to drain timing wheel: %v", err)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddCleanTask adds a clean task on given keys.
|
// AddCleanTask adds a clean task on given keys.
|
||||||
func AddCleanTask(task func() error, keys ...string) {
|
func AddCleanTask(task func() error, keys ...string) {
|
||||||
timingWheel.SetTimer(stringx.Randn(taskKeyLen), delayTask{
|
tw := timingWheel.Load().(*collection.TimingWheel)
|
||||||
|
if err := tw.SetTimer(stringx.Randn(taskKeyLen), delayTask{
|
||||||
delay: time.Second,
|
delay: time.Second,
|
||||||
task: task,
|
task: task,
|
||||||
keys: keys,
|
keys: keys,
|
||||||
}, time.Second)
|
}, time.Second); err != nil {
|
||||||
|
logx.Errorf("failed to set timer for keys: %q, error: %v", formatKeys(keys), err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func clean(key, value any) {
|
func clean(key, value any) {
|
||||||
@@ -59,7 +66,10 @@ func clean(key, value any) {
|
|||||||
next, ok := nextDelay(dt.delay)
|
next, ok := nextDelay(dt.delay)
|
||||||
if ok {
|
if ok {
|
||||||
dt.delay = next
|
dt.delay = next
|
||||||
timingWheel.SetTimer(key, dt, next)
|
tw := timingWheel.Load().(*collection.TimingWheel)
|
||||||
|
if err = tw.SetTimer(key, dt, next); err != nil {
|
||||||
|
logx.Errorf("failed to set timer for key: %s, error: %v", key, err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
msg := fmt.Sprintf("retried but failed to clear cache with keys: %q, error: %v",
|
msg := fmt.Sprintf("retried but failed to clear cache with keys: %q, error: %v",
|
||||||
formatKeys(dt.keys), err)
|
formatKeys(dt.keys), err)
|
||||||
|
|||||||
14
core/stores/cache/cleaner_test.go
vendored
14
core/stores/cache/cleaner_test.go
vendored
@@ -5,7 +5,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/collection"
|
||||||
"github.com/zeromicro/go-zero/core/proc"
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
|
"github.com/zeromicro/go-zero/core/timex"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNextDelay(t *testing.T) {
|
func TestNextDelay(t *testing.T) {
|
||||||
@@ -49,6 +51,18 @@ func TestNextDelay(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
old := timingWheel.Load()
|
||||||
|
ticker := timex.NewFakeTicker()
|
||||||
|
tw, err := collection.NewTimingWheelWithTicker(
|
||||||
|
time.Millisecond, timingWheelSlots, func(key, value any) {
|
||||||
|
clean(key, value)
|
||||||
|
}, ticker)
|
||||||
|
timingWheel.Store(tw)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
timingWheel.Store(old)
|
||||||
|
})
|
||||||
|
|
||||||
next, ok := nextDelay(test.input)
|
next, ok := nextDelay(test.input)
|
||||||
assert.Equal(t, test.ok, ok)
|
assert.Equal(t, test.ok, ok)
|
||||||
assert.Equal(t, test.output, next)
|
assert.Equal(t, test.output, next)
|
||||||
|
|||||||
@@ -3,12 +3,11 @@ package mon
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/breaker"
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
"github.com/zeromicro/go-zero/core/stringx"
|
||||||
"github.com/zeromicro/go-zero/core/timex"
|
"github.com/zeromicro/go-zero/core/timex"
|
||||||
"go.mongodb.org/mongo-driver/bson"
|
"go.mongodb.org/mongo-driver/bson"
|
||||||
@@ -573,15 +572,7 @@ func TestDecoratedCollection_LogDuration(t *testing.T) {
|
|||||||
brk: breaker.NewBreaker(),
|
brk: breaker.NewBreaker(),
|
||||||
}
|
}
|
||||||
|
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
w := logx.NewWriter(&buf)
|
|
||||||
o := logx.Reset()
|
|
||||||
logx.SetWriter(w)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
logx.Reset()
|
|
||||||
logx.SetWriter(o)
|
|
||||||
}()
|
|
||||||
|
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
c.logDuration(context.Background(), "foo", timex.Now(), nil, "bar")
|
c.logDuration(context.Background(), "foo", timex.Now(), nil, "bar")
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ package mon
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/breaker"
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/timex"
|
"github.com/zeromicro/go-zero/core/timex"
|
||||||
"go.mongodb.org/mongo-driver/mongo"
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||||
@@ -39,10 +39,7 @@ type (
|
|||||||
// MustNewModel returns a Model, exits on errors.
|
// MustNewModel returns a Model, exits on errors.
|
||||||
func MustNewModel(uri, db, collection string, opts ...Option) *Model {
|
func MustNewModel(uri, db, collection string, opts ...Option) *Model {
|
||||||
model, err := NewModel(uri, db, collection, opts...)
|
model, err := NewModel(uri, db, collection, opts...)
|
||||||
if err != nil {
|
logx.Must(err)
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,12 +3,11 @@ package mon
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFormatAddrs(t *testing.T) {
|
func TestFormatAddrs(t *testing.T) {
|
||||||
@@ -40,15 +39,7 @@ func TestFormatAddrs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_logDuration(t *testing.T) {
|
func Test_logDuration(t *testing.T) {
|
||||||
var buf strings.Builder
|
buf := logtest.NewCollector(t)
|
||||||
w := logx.NewWriter(&buf)
|
|
||||||
o := logx.Reset()
|
|
||||||
logx.SetWriter(w)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
logx.Reset()
|
|
||||||
logx.SetWriter(o)
|
|
||||||
}()
|
|
||||||
|
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
logDuration(context.Background(), "foo", "bar", time.Millisecond, nil)
|
logDuration(context.Background(), "foo", "bar", time.Millisecond, nil)
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ package monc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
"github.com/zeromicro/go-zero/core/stores/cache"
|
||||||
"github.com/zeromicro/go-zero/core/stores/mon"
|
"github.com/zeromicro/go-zero/core/stores/mon"
|
||||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||||
@@ -30,20 +30,14 @@ type Model struct {
|
|||||||
// MustNewModel returns a Model with a cache cluster, exists on errors.
|
// MustNewModel returns a Model with a cache cluster, exists on errors.
|
||||||
func MustNewModel(uri, db, collection string, c cache.CacheConf, opts ...cache.Option) *Model {
|
func MustNewModel(uri, db, collection string, c cache.CacheConf, opts ...cache.Option) *Model {
|
||||||
model, err := NewModel(uri, db, collection, c, opts...)
|
model, err := NewModel(uri, db, collection, c, opts...)
|
||||||
if err != nil {
|
logx.Must(err)
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
}
|
}
|
||||||
|
|
||||||
// MustNewNodeModel returns a Model with a cache node, exists on errors.
|
// MustNewNodeModel returns a Model with a cache node, exists on errors.
|
||||||
func MustNewNodeModel(uri, db, collection string, rds *redis.Redis, opts ...cache.Option) *Model {
|
func MustNewNodeModel(uri, db, collection string, rds *redis.Redis, opts ...cache.Option) *Model {
|
||||||
model, err := NewNodeModel(uri, db, collection, rds, opts...)
|
model, err := NewNodeModel(uri, db, collection, rds, opts...)
|
||||||
if err != nil {
|
logx.Must(err)
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func (h hook) AfterProcess(ctx context.Context, cmd red.Cmder) error {
|
|||||||
logDuration(ctx, []red.Cmder{cmd}, duration)
|
logDuration(ctx, []red.Cmder{cmd}, duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
metricReqDur.Observe(int64(duration/time.Millisecond), cmd.Name())
|
metricReqDur.Observe(duration.Milliseconds(), cmd.Name())
|
||||||
if msg := formatError(err); len(msg) > 0 {
|
if msg := formatError(err); len(msg) > 0 {
|
||||||
metricReqErr.Inc(cmd.Name(), msg)
|
metricReqErr.Inc(cmd.Name(), msg)
|
||||||
}
|
}
|
||||||
@@ -103,7 +103,7 @@ func (h hook) AfterProcessPipeline(ctx context.Context, cmds []red.Cmder) error
|
|||||||
logDuration(ctx, cmds, duration)
|
logDuration(ctx, cmds, duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
metricReqDur.Observe(int64(duration/time.Millisecond), "Pipeline")
|
metricReqDur.Observe(duration.Milliseconds(), "Pipeline")
|
||||||
if msg := formatError(batchError.Err()); len(msg) > 0 {
|
if msg := formatError(batchError.Err()); len(msg) > 0 {
|
||||||
metricReqErr.Inc("Pipeline", msg)
|
metricReqErr.Inc("Pipeline", msg)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,14 +2,18 @@ package redis
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
red "github.com/go-redis/redis/v8"
|
red "github.com/go-redis/redis/v8"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||||
ztrace "github.com/zeromicro/go-zero/core/trace"
|
ztrace "github.com/zeromicro/go-zero/core/trace"
|
||||||
tracesdk "go.opentelemetry.io/otel/trace"
|
tracesdk "go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
@@ -47,8 +51,7 @@ func TestHookProcessCase2(t *testing.T) {
|
|||||||
})
|
})
|
||||||
defer ztrace.StopAgent()
|
defer ztrace.StopAgent()
|
||||||
|
|
||||||
w, restore := injectLog()
|
w := logtest.NewCollector(t)
|
||||||
defer restore()
|
|
||||||
|
|
||||||
ctx, err := durationHook.BeforeProcess(context.Background(), red.NewCmd(context.Background()))
|
ctx, err := durationHook.BeforeProcess(context.Background(), red.NewCmd(context.Background()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -115,8 +118,7 @@ func TestHookProcessPipelineCase2(t *testing.T) {
|
|||||||
})
|
})
|
||||||
defer ztrace.StopAgent()
|
defer ztrace.StopAgent()
|
||||||
|
|
||||||
w, restore := injectLog()
|
w := logtest.NewCollector(t)
|
||||||
defer restore()
|
|
||||||
|
|
||||||
ctx, err := durationHook.BeforeProcessPipeline(context.Background(), []red.Cmder{
|
ctx, err := durationHook.BeforeProcessPipeline(context.Background(), []red.Cmder{
|
||||||
red.NewCmd(context.Background()),
|
red.NewCmd(context.Background()),
|
||||||
@@ -135,8 +137,7 @@ func TestHookProcessPipelineCase2(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestHookProcessPipelineCase3(t *testing.T) {
|
func TestHookProcessPipelineCase3(t *testing.T) {
|
||||||
w, restore := injectLog()
|
w := logtest.NewCollector(t)
|
||||||
defer restore()
|
|
||||||
|
|
||||||
assert.Nil(t, durationHook.AfterProcessPipeline(context.Background(), []red.Cmder{
|
assert.Nil(t, durationHook.AfterProcessPipeline(context.Background(), []red.Cmder{
|
||||||
red.NewCmd(context.Background()),
|
red.NewCmd(context.Background()),
|
||||||
@@ -145,8 +146,7 @@ func TestHookProcessPipelineCase3(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestHookProcessPipelineCase4(t *testing.T) {
|
func TestHookProcessPipelineCase4(t *testing.T) {
|
||||||
w, restore := injectLog()
|
w := logtest.NewCollector(t)
|
||||||
defer restore()
|
|
||||||
|
|
||||||
ctx := context.WithValue(context.Background(), startTimeKey, "foo")
|
ctx := context.WithValue(context.Background(), startTimeKey, "foo")
|
||||||
assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
|
assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
|
||||||
@@ -169,8 +169,7 @@ func TestHookProcessPipelineCase5(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLogDuration(t *testing.T) {
|
func TestLogDuration(t *testing.T) {
|
||||||
w, restore := injectLog()
|
w := logtest.NewCollector(t)
|
||||||
defer restore()
|
|
||||||
|
|
||||||
logDuration(context.Background(), []red.Cmder{
|
logDuration(context.Background(), []red.Cmder{
|
||||||
red.NewCmd(context.Background(), "get", "foo"),
|
red.NewCmd(context.Background(), "get", "foo"),
|
||||||
@@ -184,14 +183,39 @@ func TestLogDuration(t *testing.T) {
|
|||||||
assert.True(t, strings.Contains(w.String(), `get foo\nset bar 0`))
|
assert.True(t, strings.Contains(w.String(), `get foo\nset bar 0`))
|
||||||
}
|
}
|
||||||
|
|
||||||
func injectLog() (r *strings.Builder, restore func()) {
|
func TestFormatError(t *testing.T) {
|
||||||
var buf strings.Builder
|
// Test case: err is OpError
|
||||||
w := logx.NewWriter(&buf)
|
err := &net.OpError{
|
||||||
o := logx.Reset()
|
Err: mockOpError{},
|
||||||
logx.SetWriter(w)
|
|
||||||
|
|
||||||
return &buf, func() {
|
|
||||||
logx.Reset()
|
|
||||||
logx.SetWriter(o)
|
|
||||||
}
|
}
|
||||||
|
assert.Equal(t, "timeout", formatError(err))
|
||||||
|
|
||||||
|
// Test case: err is nil
|
||||||
|
assert.Equal(t, "", formatError(nil))
|
||||||
|
|
||||||
|
// Test case: err is red.Nil
|
||||||
|
assert.Equal(t, "", formatError(red.Nil))
|
||||||
|
|
||||||
|
// Test case: err is io.EOF
|
||||||
|
assert.Equal(t, "eof", formatError(io.EOF))
|
||||||
|
|
||||||
|
// Test case: err is context.DeadlineExceeded
|
||||||
|
assert.Equal(t, "context deadline", formatError(context.DeadlineExceeded))
|
||||||
|
|
||||||
|
// Test case: err is breaker.ErrServiceUnavailable
|
||||||
|
assert.Equal(t, "breaker", formatError(breaker.ErrServiceUnavailable))
|
||||||
|
|
||||||
|
// Test case: err is unknown
|
||||||
|
assert.Equal(t, "unexpected error", formatError(errors.New("some error")))
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockOpError struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mockOpError) Error() string {
|
||||||
|
return "mock error"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mockOpError) Timeout() bool {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
red "github.com/go-redis/redis/v8"
|
red "github.com/go-redis/redis/v8"
|
||||||
"github.com/zeromicro/go-zero/core/breaker"
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
"github.com/zeromicro/go-zero/core/errorx"
|
"github.com/zeromicro/go-zero/core/errorx"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/mapping"
|
"github.com/zeromicro/go-zero/core/mapping"
|
||||||
"github.com/zeromicro/go-zero/core/syncx"
|
"github.com/zeromicro/go-zero/core/syncx"
|
||||||
)
|
)
|
||||||
@@ -91,22 +91,19 @@ type (
|
|||||||
Script = red.Script
|
Script = red.Script
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MustNewRedis returns a Redis with given options.
|
||||||
|
func MustNewRedis(conf RedisConf, opts ...Option) *Redis {
|
||||||
|
rds, err := NewRedis(conf, opts...)
|
||||||
|
logx.Must(err)
|
||||||
|
return rds
|
||||||
|
}
|
||||||
|
|
||||||
// New returns a Redis with given options.
|
// New returns a Redis with given options.
|
||||||
// Deprecated: use MustNewRedis or NewRedis instead.
|
// Deprecated: use MustNewRedis or NewRedis instead.
|
||||||
func New(addr string, opts ...Option) *Redis {
|
func New(addr string, opts ...Option) *Redis {
|
||||||
return newRedis(addr, opts...)
|
return newRedis(addr, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MustNewRedis returns a Redis with given options.
|
|
||||||
func MustNewRedis(conf RedisConf, opts ...Option) *Redis {
|
|
||||||
rds, err := NewRedis(conf, opts...)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rds
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRedis returns a Redis with given options.
|
// NewRedis returns a Redis with given options.
|
||||||
func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) {
|
func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) {
|
||||||
if err := conf.Validate(); err != nil {
|
if err := conf.Validate(); err != nil {
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package redis
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
red "github.com/go-redis/redis/v8"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -41,3 +43,17 @@ func TestSplitClusterAddrs(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetCluster(t *testing.T) {
|
||||||
|
r := miniredis.RunT(t)
|
||||||
|
defer r.Close()
|
||||||
|
c, err := getCluster(&Redis{
|
||||||
|
Addr: r.Addr(),
|
||||||
|
Type: ClusterType,
|
||||||
|
tls: true,
|
||||||
|
hooks: []red.Hook{durationHook},
|
||||||
|
})
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -97,6 +97,9 @@ func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecCtx runs given exec on given keys, and returns execution result.
|
// ExecCtx runs given exec on given keys, and returns execution result.
|
||||||
|
// If DB operation succeeds, it will delete cache with given keys,
|
||||||
|
// if DB operation fails, it will return nil result and non-nil error,
|
||||||
|
// if DB operation succeeds but cache deletion fails, it will return result and non-nil error.
|
||||||
func (cc CachedConn) ExecCtx(ctx context.Context, exec ExecCtxFn, keys ...string) (
|
func (cc CachedConn) ExecCtx(ctx context.Context, exec ExecCtxFn, keys ...string) (
|
||||||
sql.Result, error) {
|
sql.Result, error) {
|
||||||
res, err := exec(ctx, cc.db)
|
res, err := exec(ctx, cc.db)
|
||||||
@@ -104,11 +107,7 @@ func (cc CachedConn) ExecCtx(ctx context.Context, exec ExecCtxFn, keys ...string
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := cc.DelCacheCtx(ctx, keys...); err != nil {
|
return res, cc.DelCacheCtx(ctx, keys...)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecNoCache runs exec with given sql statement, without affecting cache.
|
// ExecNoCache runs exec with given sql statement, without affecting cache.
|
||||||
@@ -214,6 +213,17 @@ func (cc CachedConn) SetCacheCtx(ctx context.Context, key string, val any) error
|
|||||||
return cc.cache.SetCtx(ctx, key, val)
|
return cc.cache.SetCtx(ctx, key, val)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCacheWithExpire sets v into cache with given key with given expire.
|
||||||
|
func (cc CachedConn) SetCacheWithExpire(key string, val any, expire time.Duration) error {
|
||||||
|
return cc.SetCacheWithExpireCtx(context.Background(), key, val, expire)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCacheWithExpireCtx sets v into cache with given key with given expire.
|
||||||
|
func (cc CachedConn) SetCacheWithExpireCtx(ctx context.Context, key string, val any,
|
||||||
|
expire time.Duration) error {
|
||||||
|
return cc.cache.SetWithExpireCtx(ctx, key, val, expire)
|
||||||
|
}
|
||||||
|
|
||||||
// Transact runs given fn in transaction mode.
|
// Transact runs given fn in transaction mode.
|
||||||
func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
|
func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
|
||||||
fnCtx := func(_ context.Context, session sqlx.Session) error {
|
fnCtx := func(_ context.Context, session sqlx.Session) error {
|
||||||
@@ -226,3 +236,15 @@ func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
|
|||||||
func (cc CachedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
|
func (cc CachedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
|
||||||
return cc.db.TransactCtx(ctx, fn)
|
return cc.db.TransactCtx(ctx, fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithSession returns a new CachedConn with given session.
|
||||||
|
// If query from session, the uncommitted data might be returned.
|
||||||
|
// Don't query for the uncommitted data, you should just use it,
|
||||||
|
// and don't use the cache for the uncommitted data.
|
||||||
|
// Not recommend to use cache within transactions due to consistency problem.
|
||||||
|
func (cc CachedConn) WithSession(session sqlx.Session) CachedConn {
|
||||||
|
return CachedConn{
|
||||||
|
db: sqlx.NewSqlConnFromSession(session),
|
||||||
|
cache: cc.cache,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
"github.com/alicebob/miniredis/v2"
|
"github.com/alicebob/miniredis/v2"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/fx"
|
"github.com/zeromicro/go-zero/core/fx"
|
||||||
@@ -24,6 +25,8 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||||
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
|
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
|
||||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||||
|
"github.com/zeromicro/go-zero/core/syncx"
|
||||||
|
"github.com/zeromicro/go-zero/internal/dbtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -39,7 +42,7 @@ func TestCachedConn_GetCache(t *testing.T) {
|
|||||||
var value string
|
var value string
|
||||||
err := c.GetCache("any", &value)
|
err := c.GetCache("any", &value)
|
||||||
assert.Equal(t, ErrNotFound, err)
|
assert.Equal(t, ErrNotFound, err)
|
||||||
r.Set("any", `"value"`)
|
_ = r.Set("any", `"value"`)
|
||||||
err = c.GetCache("any", &value)
|
err = c.GetCache("any", &value)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "value", value)
|
assert.Equal(t, "value", value)
|
||||||
@@ -368,6 +371,24 @@ func TestStatFromMemory(t *testing.T) {
|
|||||||
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
|
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCachedConn_DelCache(t *testing.T) {
|
||||||
|
r := redistest.CreateRedis(t)
|
||||||
|
|
||||||
|
const (
|
||||||
|
key = "user"
|
||||||
|
value = "any"
|
||||||
|
)
|
||||||
|
assert.NoError(t, r.Set(key, value))
|
||||||
|
|
||||||
|
c := NewNodeConn(&trackedConn{}, r, cache.WithExpiry(time.Second*30))
|
||||||
|
err := c.DelCache(key)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
val, err := r.Get(key)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Empty(t, val)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCachedConnQueryRow(t *testing.T) {
|
func TestCachedConnQueryRow(t *testing.T) {
|
||||||
r := redistest.CreateRedis(t)
|
r := redistest.CreateRedis(t)
|
||||||
|
|
||||||
@@ -450,6 +471,36 @@ func TestCachedConnExec(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCachedConnExecDropCache(t *testing.T) {
|
func TestCachedConnExecDropCache(t *testing.T) {
|
||||||
|
t.Run("drop cache", func(t *testing.T) {
|
||||||
|
r, err := miniredis.Run()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
defer fx.DoWithTimeout(func() error {
|
||||||
|
r.Close()
|
||||||
|
return nil
|
||||||
|
}, time.Second)
|
||||||
|
|
||||||
|
const (
|
||||||
|
key = "user"
|
||||||
|
value = "any"
|
||||||
|
)
|
||||||
|
var conn trackedConn
|
||||||
|
c := NewNodeConn(&conn, redis.New(r.Addr()), cache.WithExpiry(time.Second*30))
|
||||||
|
assert.Nil(t, c.SetCache(key, value))
|
||||||
|
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
|
||||||
|
return conn.Exec("delete from user_table where id='kevin'")
|
||||||
|
}, key)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.True(t, conn.execValue)
|
||||||
|
_, err = r.Get(key)
|
||||||
|
assert.Exactly(t, miniredis.ErrKeyNotFound, err)
|
||||||
|
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
|
||||||
|
return nil, errors.New("foo")
|
||||||
|
}, key)
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCachedConn_SetCacheWithExpire(t *testing.T) {
|
||||||
r, err := miniredis.Run()
|
r, err := miniredis.Run()
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer fx.DoWithTimeout(func() error {
|
defer fx.DoWithTimeout(func() error {
|
||||||
@@ -463,18 +514,13 @@ func TestCachedConnExecDropCache(t *testing.T) {
|
|||||||
)
|
)
|
||||||
var conn trackedConn
|
var conn trackedConn
|
||||||
c := NewNodeConn(&conn, redis.New(r.Addr()), cache.WithExpiry(time.Second*30))
|
c := NewNodeConn(&conn, redis.New(r.Addr()), cache.WithExpiry(time.Second*30))
|
||||||
assert.Nil(t, c.SetCache(key, value))
|
assert.Nil(t, c.SetCacheWithExpire(key, value, time.Minute))
|
||||||
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
|
val, err := r.Get(key)
|
||||||
return conn.Exec("delete from user_table where id='kevin'")
|
if assert.NoError(t, err) {
|
||||||
}, key)
|
ttl := r.TTL(key)
|
||||||
assert.Nil(t, err)
|
assert.True(t, ttl > 0 && ttl <= time.Minute)
|
||||||
assert.True(t, conn.execValue)
|
assert.Equal(t, fmt.Sprintf("%q", value), val)
|
||||||
_, err = r.Get(key)
|
}
|
||||||
assert.Exactly(t, miniredis.ErrKeyNotFound, err)
|
|
||||||
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
|
|
||||||
return nil, errors.New("foo")
|
|
||||||
}, key)
|
|
||||||
assert.NotNil(t, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCachedConnExecDropCacheFailed(t *testing.T) {
|
func TestCachedConnExecDropCacheFailed(t *testing.T) {
|
||||||
@@ -543,6 +589,125 @@ func TestNewConnWithCache(t *testing.T) {
|
|||||||
assert.True(t, conn.execValue)
|
assert.True(t, conn.execValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCachedConn_WithSession(t *testing.T) {
|
||||||
|
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||||
|
|
||||||
|
r := redistest.CreateRedis(t)
|
||||||
|
conn := CachedConn{
|
||||||
|
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
|
||||||
|
}
|
||||||
|
conn = conn.WithSession(sqlx.NewSessionFromTx(tx))
|
||||||
|
res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
|
||||||
|
return conn.Exec("any")
|
||||||
|
}, "foo")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
last, err := res.LastInsertId()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(2), last)
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(3), affected)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||||
|
mock.ExpectCommit()
|
||||||
|
|
||||||
|
r := redistest.CreateRedis(t)
|
||||||
|
conn := CachedConn{
|
||||||
|
db: sqlx.NewSqlConnFromDB(db),
|
||||||
|
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
|
||||||
|
}
|
||||||
|
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
|
||||||
|
conn = conn.WithSession(session)
|
||||||
|
res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
|
||||||
|
return conn.Exec("any")
|
||||||
|
}, "foo")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
last, err := res.LastInsertId()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(2), last)
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(3), affected)
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
|
||||||
|
mock.ExpectRollback()
|
||||||
|
|
||||||
|
r := redistest.CreateRedis(t)
|
||||||
|
conn := CachedConn{
|
||||||
|
db: sqlx.NewSqlConnFromDB(db),
|
||||||
|
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
|
||||||
|
}
|
||||||
|
assert.Error(t, conn.Transact(func(session sqlx.Session) error {
|
||||||
|
conn = conn.WithSession(session)
|
||||||
|
_, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
|
||||||
|
return conn.Exec("any")
|
||||||
|
}, "bar")
|
||||||
|
return err
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
|
||||||
|
mock.ExpectCommit()
|
||||||
|
|
||||||
|
r := redistest.CreateRedis(t)
|
||||||
|
conn := CachedConn{
|
||||||
|
db: sqlx.NewSqlConnFromDB(db),
|
||||||
|
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
|
||||||
|
}
|
||||||
|
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
|
||||||
|
var val string
|
||||||
|
conn = conn.WithSession(session)
|
||||||
|
err := conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
|
||||||
|
return conn.QueryRow(v, "any")
|
||||||
|
})
|
||||||
|
assert.Equal(t, "2", val)
|
||||||
|
return err
|
||||||
|
}))
|
||||||
|
val, err := r.Get("foo")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, `"2"`, val)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
|
||||||
|
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||||
|
mock.ExpectCommit()
|
||||||
|
|
||||||
|
r := redistest.CreateRedis(t)
|
||||||
|
conn := CachedConn{
|
||||||
|
db: sqlx.NewSqlConnFromDB(db),
|
||||||
|
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
|
||||||
|
}
|
||||||
|
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
|
||||||
|
var val string
|
||||||
|
conn = conn.WithSession(session)
|
||||||
|
assert.NoError(t, conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
|
||||||
|
return conn.QueryRow(v, "any")
|
||||||
|
}))
|
||||||
|
assert.Equal(t, "2", val)
|
||||||
|
_, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
|
||||||
|
return conn.Exec("any")
|
||||||
|
}, "foo")
|
||||||
|
return err
|
||||||
|
}))
|
||||||
|
val, err := r.Get("foo")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, val)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func resetStats() {
|
func resetStats() {
|
||||||
atomic.StoreUint64(&stats.Total, 0)
|
atomic.StoreUint64(&stats.Total, 0)
|
||||||
atomic.StoreUint64(&stats.Hit, 0)
|
atomic.StoreUint64(&stats.Hit, 0)
|
||||||
@@ -554,35 +719,35 @@ type dummySqlConn struct {
|
|||||||
queryRow func(any, string, ...any) error
|
queryRow func(any, string, ...any) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dummySqlConn) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
func (d dummySqlConn) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dummySqlConn) PrepareCtx(ctx context.Context, query string) (sqlx.StmtSession, error) {
|
func (d dummySqlConn) PrepareCtx(_ context.Context, _ string) (sqlx.StmtSession, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dummySqlConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
|
func (d dummySqlConn) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dummySqlConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
|
func (d dummySqlConn) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dummySqlConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
|
func (d dummySqlConn) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dummySqlConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
|
func (d dummySqlConn) TransactCtx(_ context.Context, _ func(context.Context, sqlx.Session) error) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dummySqlConn) Exec(query string, args ...any) (sql.Result, error) {
|
func (d dummySqlConn) Exec(_ string, _ ...any) (sql.Result, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) {
|
func (d dummySqlConn) Prepare(_ string) (sqlx.StmtSession, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -597,15 +762,15 @@ func (d dummySqlConn) QueryRowCtx(_ context.Context, v any, query string, args .
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dummySqlConn) QueryRowPartial(v any, query string, args ...any) error {
|
func (d dummySqlConn) QueryRowPartial(_ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dummySqlConn) QueryRows(v any, query string, args ...any) error {
|
func (d dummySqlConn) QueryRows(_ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d dummySqlConn) QueryRowsPartial(v any, query string, args ...any) error {
|
func (d dummySqlConn) QueryRowsPartial(_ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/internal/dbtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockedConn struct {
|
type mockedConn struct {
|
||||||
@@ -81,7 +81,7 @@ func (c *mockedConn) Transact(func(session Session) error) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBulkInserter(t *testing.T) {
|
func TestBulkInserter(t *testing.T) {
|
||||||
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
var conn mockedConn
|
var conn mockedConn
|
||||||
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
|
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
@@ -98,7 +98,7 @@ func TestBulkInserter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBulkInserterSuffix(t *testing.T) {
|
func TestBulkInserterSuffix(t *testing.T) {
|
||||||
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
var conn mockedConn
|
var conn mockedConn
|
||||||
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
|
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
|
||||||
`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
|
`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
|
||||||
@@ -119,7 +119,7 @@ func TestBulkInserterSuffix(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBulkInserterBadStatement(t *testing.T) {
|
func TestBulkInserterBadStatement(t *testing.T) {
|
||||||
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
var conn mockedConn
|
var conn mockedConn
|
||||||
_, err := NewBulkInserter(&conn, "foo")
|
_, err := NewBulkInserter(&conn, "foo")
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
@@ -144,19 +144,3 @@ func TestBulkInserter_Update(t *testing.T) {
|
|||||||
assert.NotNil(t, inserter.UpdateStmt("foo"))
|
assert.NotNil(t, inserter.UpdateStmt("foo"))
|
||||||
assert.NotNil(t, inserter.Insert("foo", "bar"))
|
assert.NotNil(t, inserter.Insert("foo", "bar"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
|
|
||||||
logx.Disable()
|
|
||||||
|
|
||||||
db, mock, err := sqlmock.New()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
fn(db, mock)
|
|
||||||
|
|
||||||
if err := mock.ExpectationsWereMet(); err != nil {
|
|
||||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
14
core/stores/sqlx/errors.go
Normal file
14
core/stores/sqlx/errors.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package sqlx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNotFound is an alias of sql.ErrNoRows
|
||||||
|
ErrNotFound = sql.ErrNoRows
|
||||||
|
|
||||||
|
errCantNestTx = errors.New("cannot nest transactions")
|
||||||
|
errNoRawDBFromTx = errors.New("cannot get raw db from transaction")
|
||||||
|
)
|
||||||
@@ -32,7 +32,5 @@ func mysqlAcceptable(err error) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func withMysqlAcceptable() SqlOption {
|
func withMysqlAcceptable() SqlOption {
|
||||||
return func(conn *commonSqlConn) {
|
return WithAcceptable(mysqlAcceptable)
|
||||||
conn.accept = mysqlAcceptable
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,27 +54,39 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
valueField := reflect.Indirect(v).Field(i)
|
valueField := reflect.Indirect(v).Field(i)
|
||||||
switch valueField.Kind() {
|
valueData, err := getValueInterface(valueField)
|
||||||
case reflect.Ptr:
|
if err != nil {
|
||||||
if !valueField.CanInterface() {
|
return nil, err
|
||||||
return nil, ErrNotReadableValue
|
|
||||||
}
|
|
||||||
if valueField.IsNil() {
|
|
||||||
baseValueType := mapping.Deref(valueField.Type())
|
|
||||||
valueField.Set(reflect.New(baseValueType))
|
|
||||||
}
|
|
||||||
result[key] = valueField.Interface()
|
|
||||||
default:
|
|
||||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
|
||||||
return nil, ErrNotReadableValue
|
|
||||||
}
|
|
||||||
result[key] = valueField.Addr().Interface()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result[key] = valueData
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getValueInterface(value reflect.Value) (any, error) {
|
||||||
|
switch value.Kind() {
|
||||||
|
case reflect.Ptr:
|
||||||
|
if !value.CanInterface() {
|
||||||
|
return nil, ErrNotReadableValue
|
||||||
|
}
|
||||||
|
|
||||||
|
if value.IsNil() {
|
||||||
|
baseValueType := mapping.Deref(value.Type())
|
||||||
|
value.Set(reflect.New(baseValueType))
|
||||||
|
}
|
||||||
|
|
||||||
|
return value.Interface(), nil
|
||||||
|
default:
|
||||||
|
if !value.CanAddr() || !value.Addr().CanInterface() {
|
||||||
|
return nil, ErrNotReadableValue
|
||||||
|
}
|
||||||
|
|
||||||
|
return value.Addr().Interface(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]any, error) {
|
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]any, error) {
|
||||||
fields := unwrapFields(v)
|
fields := unwrapFields(v)
|
||||||
if strict && len(columns) < len(fields) {
|
if strict && len(columns) < len(fields) {
|
||||||
@@ -88,24 +100,18 @@ func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([
|
|||||||
|
|
||||||
values := make([]any, len(columns))
|
values := make([]any, len(columns))
|
||||||
if len(taggedMap) == 0 {
|
if len(taggedMap) == 0 {
|
||||||
|
if len(fields) < len(values) {
|
||||||
|
return nil, ErrNotMatchDestination
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < len(values); i++ {
|
for i := 0; i < len(values); i++ {
|
||||||
valueField := fields[i]
|
valueField := fields[i]
|
||||||
switch valueField.Kind() {
|
valueData, err := getValueInterface(valueField)
|
||||||
case reflect.Ptr:
|
if err != nil {
|
||||||
if !valueField.CanInterface() {
|
return nil, err
|
||||||
return nil, ErrNotReadableValue
|
|
||||||
}
|
|
||||||
if valueField.IsNil() {
|
|
||||||
baseValueType := mapping.Deref(valueField.Type())
|
|
||||||
valueField.Set(reflect.New(baseValueType))
|
|
||||||
}
|
|
||||||
values[i] = valueField.Interface()
|
|
||||||
default:
|
|
||||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
|
||||||
return nil, ErrNotReadableValue
|
|
||||||
}
|
|
||||||
values[i] = valueField.Addr().Interface()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
values[i] = valueData
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for i, column := range columns {
|
for i, column := range columns {
|
||||||
@@ -140,7 +146,7 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rv := reflect.ValueOf(v)
|
rv := reflect.ValueOf(v)
|
||||||
if err := mapping.ValidatePtr(&rv); err != nil {
|
if err := mapping.ValidatePtr(rv); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,11 +158,11 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error {
|
|||||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||||
reflect.Float32, reflect.Float64,
|
reflect.Float32, reflect.Float64,
|
||||||
reflect.String:
|
reflect.String:
|
||||||
if rve.CanSet() {
|
if !rve.CanSet() {
|
||||||
return scanner.Scan(v)
|
return ErrNotSettable
|
||||||
}
|
}
|
||||||
|
|
||||||
return ErrNotSettable
|
return scanner.Scan(v)
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
columns, err := scanner.Columns()
|
columns, err := scanner.Columns()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -176,76 +182,73 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error {
|
|||||||
|
|
||||||
func unmarshalRows(v any, scanner rowsScanner, strict bool) error {
|
func unmarshalRows(v any, scanner rowsScanner, strict bool) error {
|
||||||
rv := reflect.ValueOf(v)
|
rv := reflect.ValueOf(v)
|
||||||
if err := mapping.ValidatePtr(&rv); err != nil {
|
if err := mapping.ValidatePtr(rv); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rt := reflect.TypeOf(v)
|
rt := reflect.TypeOf(v)
|
||||||
rte := rt.Elem()
|
rte := rt.Elem()
|
||||||
rve := rv.Elem()
|
rve := rv.Elem()
|
||||||
|
if !rve.CanSet() {
|
||||||
|
return ErrNotSettable
|
||||||
|
}
|
||||||
|
|
||||||
switch rte.Kind() {
|
switch rte.Kind() {
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
if rve.CanSet() {
|
ptr := rte.Elem().Kind() == reflect.Ptr
|
||||||
ptr := rte.Elem().Kind() == reflect.Ptr
|
appendFn := func(item reflect.Value) {
|
||||||
appendFn := func(item reflect.Value) {
|
if ptr {
|
||||||
if ptr {
|
rve.Set(reflect.Append(rve, item))
|
||||||
rve.Set(reflect.Append(rve, item))
|
} else {
|
||||||
} else {
|
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
|
||||||
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
fillFn := func(value any) error {
|
}
|
||||||
if rve.CanSet() {
|
fillFn := func(value any) error {
|
||||||
if err := scanner.Scan(value); err != nil {
|
if err := scanner.Scan(value); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
appendFn(reflect.ValueOf(value))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return ErrNotSettable
|
|
||||||
}
|
}
|
||||||
|
|
||||||
base := mapping.Deref(rte.Elem())
|
appendFn(reflect.ValueOf(value))
|
||||||
switch base.Kind() {
|
return nil
|
||||||
case reflect.Bool,
|
}
|
||||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
|
||||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
base := mapping.Deref(rte.Elem())
|
||||||
reflect.Float32, reflect.Float64,
|
switch base.Kind() {
|
||||||
reflect.String:
|
case reflect.Bool,
|
||||||
for scanner.Next() {
|
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||||
value := reflect.New(base)
|
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||||
if err := fillFn(value.Interface()); err != nil {
|
reflect.Float32, reflect.Float64,
|
||||||
return err
|
reflect.String:
|
||||||
}
|
for scanner.Next() {
|
||||||
|
value := reflect.New(base)
|
||||||
|
if err := fillFn(value.Interface()); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
}
|
||||||
columns, err := scanner.Columns()
|
case reflect.Struct:
|
||||||
|
columns, err := scanner.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for scanner.Next() {
|
||||||
|
value := reflect.New(base)
|
||||||
|
values, err := mapStructFieldsIntoSlice(value, columns, strict)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for scanner.Next() {
|
if err := scanner.Scan(values...); err != nil {
|
||||||
value := reflect.New(base)
|
return err
|
||||||
values, err := mapStructFieldsIntoSlice(value, columns, strict)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Scan(values...); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
appendFn(value)
|
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
return ErrUnsupportedValueType
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
appendFn(value)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return ErrUnsupportedValueType
|
||||||
}
|
}
|
||||||
|
|
||||||
return ErrNotSettable
|
return nil
|
||||||
default:
|
default:
|
||||||
return ErrUnsupportedValueType
|
return ErrUnsupportedValueType
|
||||||
}
|
}
|
||||||
@@ -257,6 +260,10 @@ func unwrapFields(v reflect.Value) []reflect.Value {
|
|||||||
|
|
||||||
for i := 0; i < indirect.NumField(); i++ {
|
for i := 0; i < indirect.NumField(); i++ {
|
||||||
child := indirect.Field(i)
|
child := indirect.Field(i)
|
||||||
|
if !child.CanSet() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if child.Kind() == reflect.Ptr && child.IsNil() {
|
if child.Kind() == reflect.Ptr && child.IsNil() {
|
||||||
baseValueType := mapping.Deref(child.Type())
|
baseValueType := mapping.Deref(child.Type())
|
||||||
child.Set(reflect.New(baseValueType))
|
child.Set(reflect.New(baseValueType))
|
||||||
|
|||||||
@@ -8,11 +8,11 @@ import (
|
|||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/internal/dbtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUnmarshalRowBool(t *testing.T) {
|
func TestUnmarshalRowBool(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -22,10 +22,22 @@ func TestUnmarshalRowBool(t *testing.T) {
|
|||||||
}, "select value from users where user=?", "anyone"))
|
}, "select value from users where user=?", "anyone"))
|
||||||
assert.True(t, value)
|
assert.True(t, value)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
var value struct {
|
||||||
|
Value bool `db:"value"`
|
||||||
|
}
|
||||||
|
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, rows, true)
|
||||||
|
}, "select value from users where user=?", "anyone"))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowBoolNotSettable(t *testing.T) {
|
func TestUnmarshalRowBoolNotSettable(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -37,7 +49,7 @@ func TestUnmarshalRowBoolNotSettable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowInt(t *testing.T) {
|
func TestUnmarshalRowInt(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -50,7 +62,7 @@ func TestUnmarshalRowInt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowInt8(t *testing.T) {
|
func TestUnmarshalRowInt8(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -63,7 +75,7 @@ func TestUnmarshalRowInt8(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowInt16(t *testing.T) {
|
func TestUnmarshalRowInt16(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -76,7 +88,7 @@ func TestUnmarshalRowInt16(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowInt32(t *testing.T) {
|
func TestUnmarshalRowInt32(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -89,7 +101,7 @@ func TestUnmarshalRowInt32(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowInt64(t *testing.T) {
|
func TestUnmarshalRowInt64(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -102,7 +114,7 @@ func TestUnmarshalRowInt64(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowUint(t *testing.T) {
|
func TestUnmarshalRowUint(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -115,7 +127,7 @@ func TestUnmarshalRowUint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowUint8(t *testing.T) {
|
func TestUnmarshalRowUint8(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -128,7 +140,7 @@ func TestUnmarshalRowUint8(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowUint16(t *testing.T) {
|
func TestUnmarshalRowUint16(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -141,7 +153,7 @@ func TestUnmarshalRowUint16(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowUint32(t *testing.T) {
|
func TestUnmarshalRowUint32(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -154,7 +166,7 @@ func TestUnmarshalRowUint32(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowUint64(t *testing.T) {
|
func TestUnmarshalRowUint64(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -167,7 +179,7 @@ func TestUnmarshalRowUint64(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowFloat32(t *testing.T) {
|
func TestUnmarshalRowFloat32(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -180,7 +192,7 @@ func TestUnmarshalRowFloat32(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowFloat64(t *testing.T) {
|
func TestUnmarshalRowFloat64(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -193,7 +205,7 @@ func TestUnmarshalRowFloat64(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowString(t *testing.T) {
|
func TestUnmarshalRowString(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
const expect = "hello"
|
const expect = "hello"
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect)
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect)
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -207,12 +219,12 @@ func TestUnmarshalRowString(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowStruct(t *testing.T) {
|
func TestUnmarshalRowStruct(t *testing.T) {
|
||||||
value := new(struct {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
Name string
|
value := new(struct {
|
||||||
Age int
|
Name string
|
||||||
})
|
Age int
|
||||||
|
})
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
|
||||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -222,15 +234,58 @@ func TestUnmarshalRowStruct(t *testing.T) {
|
|||||||
assert.Equal(t, "liao", value.Name)
|
assert.Equal(t, "liao", value.Name)
|
||||||
assert.Equal(t, 5, value.Age)
|
assert.Equal(t, 5, value.Age)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
})
|
||||||
|
|
||||||
|
errAny := errors.New("any error")
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, &mockedScanner{
|
||||||
|
colErr: errAny,
|
||||||
|
next: 1,
|
||||||
|
}, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), errAny)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
Name string
|
||||||
|
age *int
|
||||||
|
})
|
||||||
|
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, rows, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
type myString chan int
|
||||||
|
var value myString
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(&value, rows, true)
|
||||||
|
}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowStructWithTags(t *testing.T) {
|
func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||||
value := new(struct {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
Age int `db:"age"`
|
value := new(struct {
|
||||||
Name string `db:"name"`
|
Age int `db:"age"`
|
||||||
})
|
Name string `db:"name"`
|
||||||
|
})
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
|
||||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -240,6 +295,51 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
|
|||||||
assert.Equal(t, "liao", value.Name)
|
assert.Equal(t, "liao", value.Name)
|
||||||
assert.Equal(t, 5, value.Age)
|
assert.Equal(t, 5, value.Age)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
age *int `db:"age"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
})
|
||||||
|
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, rows, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
var value struct {
|
||||||
|
Age *int `db:"age"`
|
||||||
|
Name *string `db:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(&value, rows, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"))
|
||||||
|
assert.Equal(t, "liao", *value.Name)
|
||||||
|
assert.Equal(t, 5, *value.Age)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
value := new(struct {
|
||||||
|
Age int `db:"age"`
|
||||||
|
Name string
|
||||||
|
})
|
||||||
|
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRow(value, rows, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"))
|
||||||
|
assert.Equal(t, 5, value.Age)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
||||||
@@ -248,7 +348,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
|||||||
Name string `db:"name"`
|
Name string `db:"name"`
|
||||||
})
|
})
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
|
rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -259,7 +359,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsBool(t *testing.T) {
|
func TestUnmarshalRowsBool(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []bool{true, false}
|
expect := []bool{true, false}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -270,10 +370,46 @@ func TestUnmarshalRowsBool(t *testing.T) {
|
|||||||
}, "select value from users where user=?", "anyone"))
|
}, "select value from users where user=?", "anyone"))
|
||||||
assert.EqualValues(t, expect, value)
|
assert.EqualValues(t, expect, value)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
var value []bool
|
||||||
|
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(value, rows, true)
|
||||||
|
}, "select value from users where user=?", "anyone"))
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
var value struct {
|
||||||
|
value []bool `db:"value"`
|
||||||
|
}
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
var value []bool
|
||||||
|
errAny := errors.New("any")
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, &mockedScanner{
|
||||||
|
scanErr: errAny,
|
||||||
|
next: 1,
|
||||||
|
}, true)
|
||||||
|
}, "select value from users where user=?", "anyone"), errAny)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsInt(t *testing.T) {
|
func TestUnmarshalRowsInt(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []int{2, 3}
|
expect := []int{2, 3}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -287,7 +423,7 @@ func TestUnmarshalRowsInt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsInt8(t *testing.T) {
|
func TestUnmarshalRowsInt8(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []int8{2, 3}
|
expect := []int8{2, 3}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -301,7 +437,7 @@ func TestUnmarshalRowsInt8(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsInt16(t *testing.T) {
|
func TestUnmarshalRowsInt16(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []int16{2, 3}
|
expect := []int16{2, 3}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -315,7 +451,7 @@ func TestUnmarshalRowsInt16(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsInt32(t *testing.T) {
|
func TestUnmarshalRowsInt32(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []int32{2, 3}
|
expect := []int32{2, 3}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -329,7 +465,7 @@ func TestUnmarshalRowsInt32(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsInt64(t *testing.T) {
|
func TestUnmarshalRowsInt64(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []int64{2, 3}
|
expect := []int64{2, 3}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -343,7 +479,7 @@ func TestUnmarshalRowsInt64(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsUint(t *testing.T) {
|
func TestUnmarshalRowsUint(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []uint{2, 3}
|
expect := []uint{2, 3}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -357,7 +493,7 @@ func TestUnmarshalRowsUint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsUint8(t *testing.T) {
|
func TestUnmarshalRowsUint8(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []uint8{2, 3}
|
expect := []uint8{2, 3}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -371,7 +507,7 @@ func TestUnmarshalRowsUint8(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsUint16(t *testing.T) {
|
func TestUnmarshalRowsUint16(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []uint16{2, 3}
|
expect := []uint16{2, 3}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -385,7 +521,7 @@ func TestUnmarshalRowsUint16(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsUint32(t *testing.T) {
|
func TestUnmarshalRowsUint32(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []uint32{2, 3}
|
expect := []uint32{2, 3}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -399,7 +535,7 @@ func TestUnmarshalRowsUint32(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsUint64(t *testing.T) {
|
func TestUnmarshalRowsUint64(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []uint64{2, 3}
|
expect := []uint64{2, 3}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -413,7 +549,7 @@ func TestUnmarshalRowsUint64(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsFloat32(t *testing.T) {
|
func TestUnmarshalRowsFloat32(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []float32{2, 3}
|
expect := []float32{2, 3}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -427,7 +563,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsFloat64(t *testing.T) {
|
func TestUnmarshalRowsFloat64(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []float64{2, 3}
|
expect := []float64{2, 3}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -441,7 +577,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsString(t *testing.T) {
|
func TestUnmarshalRowsString(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []string{"hello", "world"}
|
expect := []string{"hello", "world"}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -457,7 +593,7 @@ func TestUnmarshalRowsString(t *testing.T) {
|
|||||||
func TestUnmarshalRowsBoolPtr(t *testing.T) {
|
func TestUnmarshalRowsBoolPtr(t *testing.T) {
|
||||||
yes := true
|
yes := true
|
||||||
no := false
|
no := false
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*bool{&yes, &no}
|
expect := []*bool{&yes, &no}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -473,7 +609,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsIntPtr(t *testing.T) {
|
func TestUnmarshalRowsIntPtr(t *testing.T) {
|
||||||
two := 2
|
two := 2
|
||||||
three := 3
|
three := 3
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*int{&two, &three}
|
expect := []*int{&two, &three}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -489,7 +625,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsInt8Ptr(t *testing.T) {
|
func TestUnmarshalRowsInt8Ptr(t *testing.T) {
|
||||||
two := int8(2)
|
two := int8(2)
|
||||||
three := int8(3)
|
three := int8(3)
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*int8{&two, &three}
|
expect := []*int8{&two, &three}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -505,7 +641,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsInt16Ptr(t *testing.T) {
|
func TestUnmarshalRowsInt16Ptr(t *testing.T) {
|
||||||
two := int16(2)
|
two := int16(2)
|
||||||
three := int16(3)
|
three := int16(3)
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*int16{&two, &three}
|
expect := []*int16{&two, &three}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -521,7 +657,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsInt32Ptr(t *testing.T) {
|
func TestUnmarshalRowsInt32Ptr(t *testing.T) {
|
||||||
two := int32(2)
|
two := int32(2)
|
||||||
three := int32(3)
|
three := int32(3)
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*int32{&two, &three}
|
expect := []*int32{&two, &three}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -537,7 +673,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsInt64Ptr(t *testing.T) {
|
func TestUnmarshalRowsInt64Ptr(t *testing.T) {
|
||||||
two := int64(2)
|
two := int64(2)
|
||||||
three := int64(3)
|
three := int64(3)
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*int64{&two, &three}
|
expect := []*int64{&two, &three}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -553,7 +689,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsUintPtr(t *testing.T) {
|
func TestUnmarshalRowsUintPtr(t *testing.T) {
|
||||||
two := uint(2)
|
two := uint(2)
|
||||||
three := uint(3)
|
three := uint(3)
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*uint{&two, &three}
|
expect := []*uint{&two, &three}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -569,7 +705,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsUint8Ptr(t *testing.T) {
|
func TestUnmarshalRowsUint8Ptr(t *testing.T) {
|
||||||
two := uint8(2)
|
two := uint8(2)
|
||||||
three := uint8(3)
|
three := uint8(3)
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*uint8{&two, &three}
|
expect := []*uint8{&two, &three}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -585,7 +721,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsUint16Ptr(t *testing.T) {
|
func TestUnmarshalRowsUint16Ptr(t *testing.T) {
|
||||||
two := uint16(2)
|
two := uint16(2)
|
||||||
three := uint16(3)
|
three := uint16(3)
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*uint16{&two, &three}
|
expect := []*uint16{&two, &three}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -601,7 +737,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsUint32Ptr(t *testing.T) {
|
func TestUnmarshalRowsUint32Ptr(t *testing.T) {
|
||||||
two := uint32(2)
|
two := uint32(2)
|
||||||
three := uint32(3)
|
three := uint32(3)
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*uint32{&two, &three}
|
expect := []*uint32{&two, &three}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -617,7 +753,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsUint64Ptr(t *testing.T) {
|
func TestUnmarshalRowsUint64Ptr(t *testing.T) {
|
||||||
two := uint64(2)
|
two := uint64(2)
|
||||||
three := uint64(3)
|
three := uint64(3)
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*uint64{&two, &three}
|
expect := []*uint64{&two, &three}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -633,7 +769,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
|
func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
|
||||||
two := float32(2)
|
two := float32(2)
|
||||||
three := float32(3)
|
three := float32(3)
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*float32{&two, &three}
|
expect := []*float32{&two, &three}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -649,7 +785,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
|
func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
|
||||||
two := float64(2)
|
two := float64(2)
|
||||||
three := float64(3)
|
three := float64(3)
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*float64{&two, &three}
|
expect := []*float64{&two, &three}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -665,7 +801,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
|
|||||||
func TestUnmarshalRowsStringPtr(t *testing.T) {
|
func TestUnmarshalRowsStringPtr(t *testing.T) {
|
||||||
hello := "hello"
|
hello := "hello"
|
||||||
world := "world"
|
world := "world"
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
expect := []*string{&hello, &world}
|
expect := []*string{&hello, &world}
|
||||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
|
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -679,25 +815,25 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsStruct(t *testing.T) {
|
func TestUnmarshalRowsStruct(t *testing.T) {
|
||||||
expect := []struct {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
Name string
|
expect := []struct {
|
||||||
Age int64
|
Name string
|
||||||
}{
|
Age int64
|
||||||
{
|
}{
|
||||||
Name: "first",
|
{
|
||||||
Age: 2,
|
Name: "first",
|
||||||
},
|
Age: 2,
|
||||||
{
|
},
|
||||||
Name: "second",
|
{
|
||||||
Age: 3,
|
Name: "second",
|
||||||
},
|
Age: 3,
|
||||||
}
|
},
|
||||||
var value []struct {
|
}
|
||||||
Name string
|
var value []struct {
|
||||||
Age int64
|
Name string
|
||||||
}
|
Age int64
|
||||||
|
}
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
|
||||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
@@ -709,6 +845,56 @@ func TestUnmarshalRowsStruct(t *testing.T) {
|
|||||||
assert.Equal(t, each.Age, value[i].Age)
|
assert.Equal(t, each.Age, value[i].Age)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
var value []struct {
|
||||||
|
Name string
|
||||||
|
Age int64
|
||||||
|
}
|
||||||
|
|
||||||
|
errAny := errors.New("any error")
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, &mockedScanner{
|
||||||
|
colErr: errAny,
|
||||||
|
next: 1,
|
||||||
|
}, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), errAny)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
var value []struct {
|
||||||
|
Name string
|
||||||
|
Age int64
|
||||||
|
}
|
||||||
|
|
||||||
|
errAny := errors.New("any error")
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, &mockedScanner{
|
||||||
|
cols: []string{"name", "age"},
|
||||||
|
scanErr: errAny,
|
||||||
|
next: 1,
|
||||||
|
}, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), errAny)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
var value []chan int
|
||||||
|
|
||||||
|
errAny := errors.New("any error")
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, &mockedScanner{
|
||||||
|
cols: []string{"name", "age"},
|
||||||
|
scanErr: errAny,
|
||||||
|
next: 1,
|
||||||
|
}, true)
|
||||||
|
}, "select name, age from users where user=?", "anyone"), ErrUnsupportedValueType)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
|
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
|
||||||
@@ -736,7 +922,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
|
|||||||
NullString sql.NullString `db:"value"`
|
NullString sql.NullString `db:"value"`
|
||||||
}
|
}
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
|
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
|
||||||
"first", "firstnullstring").AddRow("second", nil)
|
"first", "firstnullstring").AddRow("second", nil)
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
@@ -771,7 +957,7 @@ func TestUnmarshalRowsStructWithTags(t *testing.T) {
|
|||||||
Name string `db:"name"`
|
Name string `db:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
@@ -812,7 +998,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
|
|||||||
Embed
|
Embed
|
||||||
}
|
}
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
|
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
@@ -854,7 +1040,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
|
|||||||
*Embed
|
*Embed
|
||||||
}
|
}
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
|
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
@@ -888,7 +1074,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) {
|
|||||||
Age int64
|
Age int64
|
||||||
}
|
}
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
@@ -921,7 +1107,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
|
|||||||
Name string `db:"name"`
|
Name string `db:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
@@ -954,7 +1140,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
|
|||||||
Name string `db:"name"`
|
Name string `db:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
@@ -969,7 +1155,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
|
func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
|
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
@@ -1019,7 +1205,7 @@ func TestUnmarshalRowError(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
|
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
|
||||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs(
|
mock.ExpectQuery("select (.+) from users where user=?").WithArgs(
|
||||||
"anyone").WillReturnRows(rs)
|
"anyone").WillReturnRows(rs)
|
||||||
@@ -1091,7 +1277,7 @@ func TestAnonymousStructPr(t *testing.T) {
|
|||||||
Name string `db:"name"`
|
Name string `db:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{
|
rs := sqlmock.NewRows([]string{
|
||||||
"name",
|
"name",
|
||||||
"age",
|
"age",
|
||||||
@@ -1139,7 +1325,7 @@ func TestAnonymousStructPrError(t *testing.T) {
|
|||||||
Name string `db:"name"`
|
Name string `db:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
rs := sqlmock.NewRows([]string{
|
rs := sqlmock.NewRows([]string{
|
||||||
"name",
|
"name",
|
||||||
"age",
|
"age",
|
||||||
@@ -1154,7 +1340,7 @@ func TestAnonymousStructPrError(t *testing.T) {
|
|||||||
WithArgs("anyone").WillReturnRows(rs)
|
WithArgs("anyone").WillReturnRows(rs)
|
||||||
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
return unmarshalRows(&value, rows, true)
|
return unmarshalRows(&value, rows, true)
|
||||||
}, "select name, age,grade,discipline,class_name,score from users where user=?",
|
}, "select name, age, grade, discipline, class_name, score from users where user=?",
|
||||||
"anyone"))
|
"anyone"))
|
||||||
if len(value) > 0 {
|
if len(value) > 0 {
|
||||||
assert.Equal(t, value[0].score, 0)
|
assert.Equal(t, value[0].score, 0)
|
||||||
@@ -1162,23 +1348,8 @@ func TestAnonymousStructPrError(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
|
|
||||||
logx.Disable()
|
|
||||||
|
|
||||||
db, mock, err := sqlmock.New()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
fn(db, mock)
|
|
||||||
|
|
||||||
if err := mock.ExpectationsWereMet(); err != nil {
|
|
||||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockedScanner struct {
|
type mockedScanner struct {
|
||||||
|
cols []string
|
||||||
colErr error
|
colErr error
|
||||||
scanErr error
|
scanErr error
|
||||||
err error
|
err error
|
||||||
@@ -1186,7 +1357,7 @@ type mockedScanner struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockedScanner) Columns() ([]string, error) {
|
func (m *mockedScanner) Columns() ([]string, error) {
|
||||||
return nil, m.colErr
|
return m.cols, m.colErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockedScanner) Err() error {
|
func (m *mockedScanner) Err() error {
|
||||||
|
|||||||
@@ -11,9 +11,6 @@ import (
|
|||||||
// spanName is used to identify the span name for the SQL execution.
|
// spanName is used to identify the span name for the SQL execution.
|
||||||
const spanName = "sql"
|
const spanName = "sql"
|
||||||
|
|
||||||
// ErrNotFound is an alias of sql.ErrNoRows
|
|
||||||
var ErrNotFound = sql.ErrNoRows
|
|
||||||
|
|
||||||
type (
|
type (
|
||||||
// Session stands for raw connections or transaction sessions
|
// Session stands for raw connections or transaction sessions
|
||||||
Session interface {
|
Session interface {
|
||||||
@@ -131,6 +128,13 @@ func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
|
|||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewSqlConnFromSession returns a SqlConn with the given session.
|
||||||
|
func NewSqlConnFromSession(session Session) SqlConn {
|
||||||
|
return txConn{
|
||||||
|
Session: session,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) {
|
func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) {
|
||||||
return db.ExecCtx(context.Background(), q, args...)
|
return db.ExecCtx(context.Background(), q, args...)
|
||||||
}
|
}
|
||||||
@@ -287,12 +291,19 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db *commonSqlConn) acceptable(err error) bool {
|
func (db *commonSqlConn) acceptable(err error) bool {
|
||||||
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
|
if err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled {
|
||||||
if db.accept == nil {
|
return true
|
||||||
return ok
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ok || db.accept(err)
|
if _, ok := err.(acceptableError); ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.accept == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.accept(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
|
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
|
||||||
@@ -395,3 +406,11 @@ func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any)
|
|||||||
return unmarshalRows(v, rows, false)
|
return unmarshalRows(v, rows, false)
|
||||||
}, s.query, args...)
|
}, s.query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithAcceptable returns a SqlOption that setting the acceptable function.
|
||||||
|
// acceptable is the func to check if the error can be accepted.
|
||||||
|
func WithAcceptable(acceptable func(err error) bool) SqlOption {
|
||||||
|
return func(conn *commonSqlConn) {
|
||||||
|
conn.accept = acceptable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,13 +2,16 @@ package sqlx
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/trace/tracetest"
|
"github.com/zeromicro/go-zero/core/trace/tracetest"
|
||||||
|
"github.com/zeromicro/go-zero/internal/dbtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
const mockedDatasource = "sqlmock"
|
const mockedDatasource = "sqlmock"
|
||||||
@@ -54,8 +57,214 @@ func TestSqlConn(t *testing.T) {
|
|||||||
assert.Equal(t, 14, len(me.GetSpans()))
|
assert.Equal(t, 14, len(me.GetSpans()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlConn_RawDB(t *testing.T) {
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
var val string
|
||||||
|
assert.NoError(t, conn.QueryRow(&val, "any"))
|
||||||
|
assert.Equal(t, "bar", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
var val string
|
||||||
|
assert.NoError(t, conn.QueryRowPartial(&val, "any"))
|
||||||
|
assert.Equal(t, "bar", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
var vals []string
|
||||||
|
assert.NoError(t, conn.QueryRows(&vals, "any"))
|
||||||
|
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
var vals []string
|
||||||
|
assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
|
||||||
|
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlConn_Errors(t *testing.T) {
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
conn.(*commonSqlConn).connProv = func() (*sql.DB, error) {
|
||||||
|
return nil, errors.New("error")
|
||||||
|
}
|
||||||
|
_, err := conn.Prepare("any")
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectExec("any").WillReturnError(breaker.ErrServiceUnavailable)
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
_, err := conn.Exec("any")
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectPrepare("any").WillReturnError(breaker.ErrServiceUnavailable)
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
_, err := conn.Prepare("any")
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectRollback()
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
err := conn.Transact(func(session Session) error {
|
||||||
|
return breaker.ErrServiceUnavailable
|
||||||
|
})
|
||||||
|
assert.Equal(t, breaker.ErrServiceUnavailable, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectQuery("any").WillReturnError(breaker.ErrServiceUnavailable)
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
var vals []string
|
||||||
|
err := conn.QueryRows(&vals, "any")
|
||||||
|
assert.Equal(t, breaker.ErrServiceUnavailable, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatement(t *testing.T) {
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectPrepare("any").WillBeClosed()
|
||||||
|
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
stmt, err := conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NoError(t, stmt.Close())
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectPrepare("any").WillBeClosed()
|
||||||
|
|
||||||
|
stmt, err := tx.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
st := statement{
|
||||||
|
query: "foo",
|
||||||
|
stmt: stmt,
|
||||||
|
}
|
||||||
|
assert.NoError(t, st.Close())
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectPrepare("any")
|
||||||
|
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||||
|
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
stmt, err := conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
res, err := stmt.Exec()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
lastInsertID, err := res.LastInsertId()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(2), lastInsertID)
|
||||||
|
rowsAffected, err := res.RowsAffected()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(3), rowsAffected)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectPrepare("any")
|
||||||
|
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(row)
|
||||||
|
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
stmt, err := conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var val string
|
||||||
|
err = stmt.QueryRow(&val)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "bar", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectPrepare("any")
|
||||||
|
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(row)
|
||||||
|
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
stmt, err := conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var val string
|
||||||
|
err = stmt.QueryRowPartial(&val)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "bar", val)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectPrepare("any")
|
||||||
|
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
stmt, err := conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var vals []string
|
||||||
|
assert.NoError(t, stmt.QueryRows(&vals))
|
||||||
|
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectPrepare("any")
|
||||||
|
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
stmt, err := conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var vals []string
|
||||||
|
assert.NoError(t, stmt.QueryRowsPartial(&vals))
|
||||||
|
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreakerWithFormatError(t *testing.T) {
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
conn := NewSqlConnFromDB(db, withMysqlAcceptable())
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
var val string
|
||||||
|
if !assert.NotEqual(t, breaker.ErrServiceUnavailable,
|
||||||
|
conn.QueryRow(&val, "any ?, ?", "foo")) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBreakerWithScanError(t *testing.T) {
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
conn := NewSqlConnFromDB(db, withMysqlAcceptable())
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
var val int
|
||||||
|
if !assert.NotEqual(t, breaker.ErrServiceUnavailable, conn.QueryRow(&val, "any")) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
||||||
connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
||||||
var db *sql.DB
|
var db *sql.DB
|
||||||
var err error
|
var err error
|
||||||
db, mock, err = sqlmock.New()
|
db, mock, err = sqlmock.New()
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ func (e *realSqlGuard) finish(ctx context.Context, err error) {
|
|||||||
logSqlError(ctx, e.stmt, err)
|
logSqlError(ctx, e.stmt, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
metricReqDur.Observe(int64(duration/time.Millisecond), e.command)
|
metricReqDur.Observe(duration.Milliseconds(), e.command)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *realSqlGuard) start(q string, args ...any) error {
|
func (e *realSqlGuard) start(q string, args ...any) error {
|
||||||
|
|||||||
@@ -15,11 +15,27 @@ type (
|
|||||||
Rollback() error
|
Rollback() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
txConn struct {
|
||||||
|
Session
|
||||||
|
}
|
||||||
|
|
||||||
txSession struct {
|
txSession struct {
|
||||||
*sql.Tx
|
*sql.Tx
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (s txConn) RawDB() (*sql.DB, error) {
|
||||||
|
return nil, errNoRawDBFromTx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s txConn) Transact(_ func(Session) error) error {
|
||||||
|
return errCantNestTx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s txConn) TransactCtx(_ context.Context, _ func(context.Context, Session) error) error {
|
||||||
|
return errCantNestTx
|
||||||
|
}
|
||||||
|
|
||||||
// NewSessionFromTx returns a Session with the given sql.Tx.
|
// NewSessionFromTx returns a Session with the given sql.Tx.
|
||||||
// Use it with caution, it's provided for other ORM to interact with.
|
// Use it with caution, it's provided for other ORM to interact with.
|
||||||
func NewSessionFromTx(tx *sql.Tx) Session {
|
func NewSessionFromTx(tx *sql.Tx) Session {
|
||||||
|
|||||||
@@ -6,7 +6,10 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
|
"github.com/zeromicro/go-zero/internal/dbtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -23,51 +26,51 @@ func (mt *mockTx) Commit() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mt *mockTx) Exec(q string, args ...any) (sql.Result, error) {
|
func (mt *mockTx) Exec(_ string, _ ...any) (sql.Result, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
func (mt *mockTx) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mt *mockTx) Prepare(query string) (StmtSession, error) {
|
func (mt *mockTx) Prepare(_ string) (StmtSession, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
|
func (mt *mockTx) PrepareCtx(_ context.Context, _ string) (StmtSession, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mt *mockTx) QueryRow(v any, q string, args ...any) error {
|
func (mt *mockTx) QueryRow(_ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mt *mockTx) QueryRowCtx(ctx context.Context, v any, query string, args ...any) error {
|
func (mt *mockTx) QueryRowCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mt *mockTx) QueryRowPartial(v any, q string, args ...any) error {
|
func (mt *mockTx) QueryRowPartial(_ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
|
func (mt *mockTx) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mt *mockTx) QueryRows(v any, q string, args ...any) error {
|
func (mt *mockTx) QueryRows(_ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mt *mockTx) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
|
func (mt *mockTx) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mt *mockTx) QueryRowsPartial(v any, q string, args ...any) error {
|
func (mt *mockTx) QueryRowsPartial(_ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
|
func (mt *mockTx) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,3 +104,209 @@ func TestTransactRollback(t *testing.T) {
|
|||||||
assert.Equal(t, mockRollback, mock.status)
|
assert.Equal(t, mockRollback, mock.status)
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTxExceptions(t *testing.T) {
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectCommit()
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
assert.NoError(t, conn.Transact(func(session Session) error {
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
conn := &commonSqlConn{
|
||||||
|
connProv: func() (*sql.DB, error) {
|
||||||
|
return nil, errors.New("foo")
|
||||||
|
},
|
||||||
|
beginTx: begin,
|
||||||
|
onError: func(ctx context.Context, err error) {},
|
||||||
|
brk: breaker.NewBreaker(),
|
||||||
|
}
|
||||||
|
assert.Error(t, conn.Transact(func(session Session) error {
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||||
|
_, err := conn.RawDB()
|
||||||
|
assert.Equal(t, errNoRawDBFromTx, err)
|
||||||
|
assert.Equal(t, errCantNestTx, conn.Transact(nil))
|
||||||
|
assert.Equal(t, errCantNestTx, conn.TransactCtx(context.Background(), nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
assert.Error(t, conn.Transact(func(session Session) error {
|
||||||
|
return errors.New("foo")
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectRollback().WillReturnError(errors.New("foo"))
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
assert.Error(t, conn.Transact(func(session Session) error {
|
||||||
|
panic("foo")
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectRollback()
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
assert.Error(t, conn.Transact(func(session Session) error {
|
||||||
|
panic(errors.New("foo"))
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTxSession(t *testing.T) {
|
||||||
|
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||||
|
res, err := conn.Exec("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
last, err := res.LastInsertId()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(2), last)
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(3), affected)
|
||||||
|
|
||||||
|
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
|
||||||
|
_, err = conn.Exec("any")
|
||||||
|
assert.Equal(t, "foo", err.Error())
|
||||||
|
})
|
||||||
|
|
||||||
|
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectPrepare("any")
|
||||||
|
stmt, err := conn.Prepare("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, stmt)
|
||||||
|
|
||||||
|
mock.ExpectPrepare("any").WillReturnError(errors.New("foo"))
|
||||||
|
_, err = conn.Prepare("any")
|
||||||
|
assert.Equal(t, "foo", err.Error())
|
||||||
|
})
|
||||||
|
|
||||||
|
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||||
|
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
var val string
|
||||||
|
err := conn.QueryRow(&val, "any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "foo", val)
|
||||||
|
|
||||||
|
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
|
||||||
|
err = conn.QueryRow(&val, "any")
|
||||||
|
assert.Equal(t, "foo", err.Error())
|
||||||
|
})
|
||||||
|
|
||||||
|
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||||
|
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
var val string
|
||||||
|
err := conn.QueryRowPartial(&val, "any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "foo", val)
|
||||||
|
|
||||||
|
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
|
||||||
|
err = conn.QueryRowPartial(&val, "any")
|
||||||
|
assert.Equal(t, "foo", err.Error())
|
||||||
|
})
|
||||||
|
|
||||||
|
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||||
|
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
var val []string
|
||||||
|
err := conn.QueryRows(&val, "any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"foo", "bar"}, val)
|
||||||
|
|
||||||
|
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
|
||||||
|
err = conn.QueryRows(&val, "any")
|
||||||
|
assert.Equal(t, "foo", err.Error())
|
||||||
|
})
|
||||||
|
|
||||||
|
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||||
|
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
|
||||||
|
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||||
|
var val []string
|
||||||
|
err := conn.QueryRowsPartial(&val, "any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"foo", "bar"}, val)
|
||||||
|
|
||||||
|
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
|
||||||
|
err = conn.QueryRowsPartial(&val, "any")
|
||||||
|
assert.Equal(t, "foo", err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTxRollback(t *testing.T) {
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||||
|
mock.ExpectQuery("foo").WillReturnError(errors.New("foo"))
|
||||||
|
mock.ExpectRollback()
|
||||||
|
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
err := conn.Transact(func(session Session) error {
|
||||||
|
c := NewSqlConnFromSession(session)
|
||||||
|
_, err := c.Exec("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var val string
|
||||||
|
return c.QueryRow(&val, "foo")
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
|
||||||
|
mock.ExpectRollback()
|
||||||
|
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
err := conn.Transact(func(session Session) error {
|
||||||
|
c := NewSqlConnFromSession(session)
|
||||||
|
if _, err := c.Exec("any"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var val string
|
||||||
|
assert.NoError(t, c.QueryRow(&val, "foo"))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||||
|
mock.ExpectQuery("foo").WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("bar"))
|
||||||
|
mock.ExpectCommit()
|
||||||
|
|
||||||
|
conn := NewSqlConnFromDB(db)
|
||||||
|
err := conn.Transact(func(session Session) error {
|
||||||
|
c := NewSqlConnFromSession(session)
|
||||||
|
_, err := c.Exec("any")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var val string
|
||||||
|
assert.NoError(t, c.QueryRow(&val, "foo"))
|
||||||
|
assert.Equal(t, "bar", val)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func runTxTest(t *testing.T, f func(conn SqlConn, mock sqlmock.Sqlmock)) {
|
||||||
|
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
|
||||||
|
sess := NewSessionFromTx(tx)
|
||||||
|
conn := NewSqlConnFromSession(sess)
|
||||||
|
f(conn, mock)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -51,7 +51,13 @@ func escape(input string) string {
|
|||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func format(query string, args ...any) (string, error) {
|
func format(query string, args ...any) (val string, err error) {
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
err = newAcceptableError(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
numArgs := len(args)
|
numArgs := len(args)
|
||||||
if numArgs == 0 {
|
if numArgs == 0 {
|
||||||
return query, nil
|
return query, nil
|
||||||
@@ -66,7 +72,8 @@ func format(query string, args ...any) (string, error) {
|
|||||||
switch ch {
|
switch ch {
|
||||||
case '?':
|
case '?':
|
||||||
if argIndex >= numArgs {
|
if argIndex >= numArgs {
|
||||||
return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex)
|
return "", fmt.Errorf("%d ? in sql, but only %d arguments provided",
|
||||||
|
argIndex+1, numArgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
writeValue(&b, args[argIndex])
|
writeValue(&b, args[argIndex])
|
||||||
@@ -165,3 +172,17 @@ func writeValue(buf *strings.Builder, arg any) {
|
|||||||
buf.WriteString(mapping.Repr(v))
|
buf.WriteString(mapping.Repr(v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type acceptableError struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAcceptableError(err error) error {
|
||||||
|
return acceptableError{
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e acceptableError) Error() string {
|
||||||
|
return e.err.Error()
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,12 +20,14 @@ func ForAtomicBool(val bool) *AtomicBool {
|
|||||||
// CompareAndSwap compares current value with given old, if equals, set to given val.
|
// CompareAndSwap compares current value with given old, if equals, set to given val.
|
||||||
func (b *AtomicBool) CompareAndSwap(old, val bool) bool {
|
func (b *AtomicBool) CompareAndSwap(old, val bool) bool {
|
||||||
var ov, nv uint32
|
var ov, nv uint32
|
||||||
|
|
||||||
if old {
|
if old {
|
||||||
ov = 1
|
ov = 1
|
||||||
}
|
}
|
||||||
if val {
|
if val {
|
||||||
nv = 1
|
nv = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
return atomic.CompareAndSwapUint32((*uint32)(b), ov, nv)
|
return atomic.CompareAndSwapUint32((*uint32)(b), ov, nv)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,8 @@ func (manager *ResourceManager) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetResource returns the resource associated with given key.
|
// GetResource returns the resource associated with given key.
|
||||||
func (manager *ResourceManager) GetResource(key string, create func() (io.Closer, error)) (io.Closer, error) {
|
func (manager *ResourceManager) GetResource(key string, create func() (io.Closer, error)) (
|
||||||
|
io.Closer, error) {
|
||||||
val, err := manager.singleFlight.Do(key, func() (any, error) {
|
val, err := manager.singleFlight.Do(key, func() (any, error) {
|
||||||
manager.lock.RLock()
|
manager.lock.RLock()
|
||||||
resource, ok := manager.resources[key]
|
resource, ok := manager.resources[key]
|
||||||
|
|||||||
@@ -9,25 +9,44 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestTimeoutLimit(t *testing.T) {
|
func TestTimeoutLimit(t *testing.T) {
|
||||||
limit := NewTimeoutLimit(2)
|
tests := []struct {
|
||||||
assert.Nil(t, limit.Borrow(time.Millisecond*200))
|
name string
|
||||||
assert.Nil(t, limit.Borrow(time.Millisecond*200))
|
interval time.Duration
|
||||||
var wait1, wait2, wait3 sync.WaitGroup
|
}{
|
||||||
wait1.Add(1)
|
{
|
||||||
wait2.Add(1)
|
name: "no wait",
|
||||||
wait3.Add(1)
|
},
|
||||||
go func() {
|
{
|
||||||
wait1.Wait()
|
name: "wait",
|
||||||
wait2.Done()
|
interval: time.Millisecond * 100,
|
||||||
assert.Nil(t, limit.Return())
|
},
|
||||||
wait3.Done()
|
}
|
||||||
}()
|
|
||||||
wait1.Done()
|
for _, test := range tests {
|
||||||
wait2.Wait()
|
test := test
|
||||||
assert.Nil(t, limit.Borrow(time.Second))
|
t.Run(test.name, func(t *testing.T) {
|
||||||
wait3.Wait()
|
limit := NewTimeoutLimit(2)
|
||||||
assert.Equal(t, ErrTimeout, limit.Borrow(time.Millisecond*100))
|
assert.Nil(t, limit.Borrow(time.Millisecond*200))
|
||||||
assert.Nil(t, limit.Return())
|
assert.Nil(t, limit.Borrow(time.Millisecond*200))
|
||||||
assert.Nil(t, limit.Return())
|
var wait1, wait2, wait3 sync.WaitGroup
|
||||||
assert.Equal(t, ErrLimitReturn, limit.Return())
|
wait1.Add(1)
|
||||||
|
wait2.Add(1)
|
||||||
|
wait3.Add(1)
|
||||||
|
go func() {
|
||||||
|
wait1.Wait()
|
||||||
|
wait2.Done()
|
||||||
|
time.Sleep(test.interval)
|
||||||
|
assert.Nil(t, limit.Return())
|
||||||
|
wait3.Done()
|
||||||
|
}()
|
||||||
|
wait1.Done()
|
||||||
|
wait2.Wait()
|
||||||
|
assert.Nil(t, limit.Borrow(time.Second))
|
||||||
|
wait3.Wait()
|
||||||
|
assert.Equal(t, ErrTimeout, limit.Borrow(time.Millisecond*100))
|
||||||
|
assert.Nil(t, limit.Return())
|
||||||
|
assert.Nil(t, limit.Return())
|
||||||
|
assert.Equal(t, ErrLimitReturn, limit.Return())
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package threading
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@@ -13,6 +14,11 @@ func GoSafe(fn func()) {
|
|||||||
go RunSafe(fn)
|
go RunSafe(fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GoSafeCtx runs the given fn using another goroutine, recovers if fn panics with ctx.
|
||||||
|
func GoSafeCtx(ctx context.Context, fn func()) {
|
||||||
|
go RunSafeCtx(ctx, fn)
|
||||||
|
}
|
||||||
|
|
||||||
// RoutineId is only for debug, never use it in production.
|
// RoutineId is only for debug, never use it in production.
|
||||||
func RoutineId() uint64 {
|
func RoutineId() uint64 {
|
||||||
b := make([]byte, 64)
|
b := make([]byte, 64)
|
||||||
@@ -31,3 +37,10 @@ func RunSafe(fn func()) {
|
|||||||
|
|
||||||
fn()
|
fn()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RunSafeCtx runs the given fn, recovers if fn panics with ctx.
|
||||||
|
func RunSafeCtx(ctx context.Context, fn func()) {
|
||||||
|
defer rescue.RecoverCtx(ctx)
|
||||||
|
|
||||||
|
fn()
|
||||||
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user