Compare commits

...

215 Commits

Author SHA1 Message Date
Kevin Wan
3bbc90ec24 refactor: move json related header vars to internal (#1840)
* refactor: move json related header vars to internal

* refactor: use header.ContentType
2022-04-28 15:12:04 +08:00
Kevin Wan
cef83efd4e fix #1838 (#1839) 2022-04-28 11:25:26 +08:00
anqiansong
cc09ab2aba feat: Support model code generation for multi tables (#1836)
* Support model code generation for multi tables

* Format code

* Format code

Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-04-28 10:01:04 +08:00
Kevin Wan
f7a60cdc24 fix: remove deprecated dependencies (#1837)
* fix: remove deprecated dependencies

* backup

* fix test error
2022-04-27 21:34:54 +08:00
Kevin Wan
c3a49ece8d Update readme-cn.md
add go-zero users.
2022-04-27 13:50:54 +08:00
Kevin Wan
1a38eddffe refactor: simplify the code (#1835) 2022-04-27 10:44:24 +08:00
Kevin Wan
5bcee4cf7c fix #1806 (#1833)
* fix #1806

* chore: refine error text
2022-04-27 00:01:31 +08:00
Kevin Wan
5c9fae7e62 feat: support sub domain for cors (#1827) 2022-04-25 21:56:59 +08:00
Kevin Wan
ec3e02624c feat: upgrade grpc to 1.46, and remove the deprecated grpc.WithBalancerName (#1820) 2022-04-24 22:42:40 +08:00
chen quan
22b157bb6c chore: optimize code (#1818)
Signed-off-by: chenquan <chenquan.dev@gmail.com>
2022-04-23 22:02:04 +08:00
Kevin Wan
095b603788 chore: remove gofumpt -s flag, default to be enabled (#1816) 2022-04-22 14:37:17 +08:00
Kevin Wan
bc3c9484d1 chore: refactor (#1814) 2022-04-22 09:37:09 +08:00
chen quan
162e9cef86 feat: add trace in redis & mon & sql (#1799)
* feat: add sub spanId with redis

Signed-off-by: chenquan <chenquan.dev@gmail.com>

* add tests

Signed-off-by: chenquan <chenquan.dev@gmail.com>

* fix a bug

Signed-off-by: chenquan <chenquan.dev@gmail.com>

* feat: add sub spanId in sql

Signed-off-by: chenquan <chenquan.dev@gmail.com>

* feat: add sub spanId in mon

Signed-off-by: chenquan <chenquan.dev@gmail.com>

* chore: optimize code

Signed-off-by: chenquan <chenquan.dev@gmail.com>

* feat: add breaker in warpSession

Signed-off-by: chenquan <chenquan.dev@gmail.com>

* chore: optimize code

Signed-off-by: chenquan <chenquan.dev@gmail.com>

* test: add tests

Signed-off-by: chenquan <chenquan.dev@gmail.com>

* chore: reformat code

Signed-off-by: chenquan <chenquan.dev@gmail.com>

* fix: fix typo

Signed-off-by: chenquan <chenquan.dev@gmail.com>

* fix a bug

Signed-off-by: chenquan <chenquan.dev@gmail.com>
2022-04-22 09:04:44 +08:00
Vee Zhang
94ddb3380e fix: rest: WriteJson get 200 when Marshal failed. (#1803)
Only the first WriteHeader call takes effect.
2022-04-21 21:55:01 +08:00
anqiansong
16c61c6657 chore: Embed unit test data (#1812)
* Embed unit test data

* Add testdata

Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-04-21 21:49:09 +08:00
chowyu12
14bf2f33f7 add go-grpc_opt and go_opt for grpc new command (#1769)
Co-authored-by: zhouyy <zhouyy@ickey.cn>
2022-04-21 16:45:56 +08:00
anqiansong
305587aa81 fix: Fix issue #1810 (#1811)
* Fix #1810

* Remove go embed

* Format code

* Remove useless code

Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-04-21 15:22:43 +08:00
Kevin Wan
2cdff97934 feat: use mongodb official driver instead of mgo (#1782)
* wip: backup

* wip: backup

* wip: backup

* backup

* backup

* backup

* add more tests

* fix wrong dependency

* fix lint errors

* remove test due to data race

* add tests

* fix test error

* add mon.Model

* add mon.Model unmarshal

* add monc

* add more tests for monc

* add more tests for monc

* add docs for mon and monc packages

* fix doc errors

* chhore: add comment

* chore: fix test bug

* chore: refine tests

* chore: remove primitive.NewObjectID in test code

* chore: rename test files for typo
2022-04-19 14:03:04 +08:00
fang duan
bbe1249ecb update rpc generate sample proto file (#1709)
* update rpc generate sample proto file

* update
2022-04-19 10:59:16 +08:00
Fyn
e62870e268 feat(goctl): go work multi-module support (#1800)
* feat(goctl): go work multi-module support

Resolve: #1793

* chore: print log when getting project ctx fails
2022-04-18 20:36:41 +08:00
Kevin Wan
92b450eb11 fix: ignore timeout on websocket (#1802) 2022-04-18 20:14:46 +08:00
杨圆建
d58cf7a12a fix: Hdel check result & Pfadd check result (#1801) 2022-04-18 17:38:36 +08:00
Fyn
036d803fbb docs(goctl): goctl 1.3.4 migration note (#1780)
* docs(goctl): goctl 1.3.4 migration note

* adds a simple lang check
* adds migration notes

* chore: remove i18n

* chore: remove todo
2022-04-18 14:42:13 +08:00
chen quan
c6ab11b14f chore: use grpc.WithTransportCredentials and insecure.NewCredentials() instead of grpc.WithInsecure (#1798)
Signed-off-by: chenquan <chenquan.dev@gmail.com>
2022-04-18 14:15:09 +08:00
Kevin Wan
9e20b1bbfe chhore: fix usage typo (#1797) 2022-04-17 21:17:31 +08:00
fang duan
fadef0ccd9 goctl api new should given a service_name explictly (#1688) 2022-04-17 20:59:18 +08:00
fang duan
4382ec0e0d show help when running goctl api without any flags (#1678)
close #1676
2022-04-17 20:58:12 +08:00
fang duan
db99addc64 show help when running goctl docker without any flags (#1679)
close #1677
2022-04-17 20:57:46 +08:00
fang duan
97bf3856c1 show help when running goctl rpc protoc without any flags (#1683) 2022-04-17 20:57:26 +08:00
fang duan
ff6c6558dd improve goctl rpc new (#1687) 2022-04-17 20:56:56 +08:00
Kevin Wan
5d4e7c84ee revert postgres package refactor (#1796)
* Revert "refactor: move postgres to pg package (#1781)"

This reverts commit ba8ac974aa.

* remove pg, use postgres
2022-04-17 12:07:48 +08:00
Kevin Wan
cb4fcf2c6c fix marshal ptr in httpc (#1789)
* fix marshal ptr in httpc

* add more tests

* add more tests

* add more tests

* fix issue on options and optional both provided
2022-04-15 19:07:34 +08:00
Fyn
ee88abce14 fix(goctl): api/new/api.tpl (#1788) 2022-04-14 23:43:48 +08:00
Kevin Wan
ecc3653d44 fix #1729 (#1783) 2022-04-13 19:06:00 +08:00
Kevin Wan
ba8ac974aa refactor: move postgres to pg package (#1781) 2022-04-13 12:46:09 +08:00
Kevin Wan
50de01fb49 feat: add httpc.Do & httpc.Service.Do (#1775)
* backup

* backup

* backup

* feat: add httpc.Do & httpc.Service.Do

* fix: not using strings.Cut, it's from Go 1.18

* chore: remove redudant code

* feat: httpc.Do finished

* chore: fix reviewdog

* chore: break loop if found

* add more tests
2022-04-11 11:00:28 +08:00
方航
fabea4c448 fix bug: crash when generate model with goctl. (#1777)
* fix bug: crash when generate model with goctl.

situation: column name with line.

CREATE TABLE test (
id int NOT NULL AUTO_INCREMENT,
zh-cn text CHARACTER SET utf8 COLLATE utf8_general_ci COMMENT '中文简体',
PRIMARY KEY (id) USING BTREE,
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;

* group imports

group imports

* Use

go-zero/tools/goctl/util/string.go
 func SafeString(in string) string {
instead of ReplaceAll

Co-authored-by: 方航 <fanghang@tange.ai>
2022-04-11 10:11:40 +08:00
Fyn
6d9dfc08f9 feat(goctl): supports api multi-level importing (#1747)
* feat(goctl): supports api  multi-level importing

Resolves: #1744

* fix(goctl): import-cycle, etc.

import-cycle will not be allowed
e.g., a.api -> b.api -> a.api
regular multiple-import will be allowed
e.g., a.api -> b.api -> c.api
                   -> c.api

* refactor(goctl): adds comments to exported var

* fix(goctl): typo in a comment
2022-04-09 23:26:57 +08:00
anqiansong
252fabcc4b fix nil pointer if group not exists (#1773)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-04-09 11:54:30 +08:00
Kevin Wan
415c4c91fc fix: model unique keys generated differently in each re-generation (#1771) 2022-04-09 00:25:23 +08:00
fang duan
0cc9d4ff8d show help when running goctl rpc template without any flags (#1685)
close #1684
2022-04-08 22:28:45 +08:00
Kevin Wan
8bc34defc4 chore: avoid deadlock after stopping TimingWheel (#1768) 2022-04-07 11:50:18 +08:00
anqiansong
8dd764679c Fix #1765 (#1767)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-04-07 10:40:21 +08:00
Kevin Wan
9fe868ade9 chore: remove legacy code (#1766) 2022-04-06 23:24:20 +08:00
Kevin Wan
4e48286838 chore: add doc (#1764) 2022-04-06 22:42:40 +08:00
Kevin Wan
ab01442d46 add more tests (#1763)
* feat: add goctl docker build scripts

* chore: add more tests
2022-04-06 16:09:06 +08:00
Kevin Wan
8694e38384 feat: add goctl docker build scripts (#1760) 2022-04-05 13:07:05 +08:00
Kevin Wan
d5e550e79b Update readme-cn.md 2022-04-05 11:51:53 +08:00
Kevin Wan
affdab660e Update readme.md 2022-04-05 11:51:09 +08:00
Kevin Wan
7d5858e83a Update readme.md 2022-04-05 11:08:00 +08:00
Kevin Wan
815a6a6485 Update readme-cn.md 2022-04-05 11:07:37 +08:00
benqi
475d17e17d feat: support ctx in kv methods (#1759) 2022-04-04 23:19:58 +08:00
Kevin Wan
8472415472 fix #1754 (#1757) 2022-04-04 22:13:08 +08:00
Kevin Wan
faad6e27e3 feat: use go:embed to embed templates (#1756) 2022-04-04 13:12:05 +08:00
anqiansong
58a0b17451 Support goctl env install (#1752)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-04-03 21:58:43 +08:00
Kevin Wan
89eccfdb97 chore: update go-zero to v1.3.2 in goctl (#1750) 2022-04-03 20:44:33 +08:00
Kevin Wan
78ea0769fd feat: simplify httpc (#1748)
* feat: simplify httpc

* chore: fix lint errors

* chore: fix log url issue

* chore: fix log url issue

* refactor: handle resp & err in ResponseHandler

* chore: remove unnecessary var names in return clause
2022-04-03 14:32:27 +08:00
Kevin Wan
e0fa8d820d feat: return original value of setbit in redis (#1746) 2022-04-02 20:25:51 +08:00
Kevin Wan
dfd58c213c fix: model generation bug on with cache (#1743)
* fix: model generation bug on with cache

* chore: refine template

* chore: fix test failure
2022-04-02 15:36:06 +08:00
Kevin Wan
83cacf51b7 chore: update goctl version to 1.3.4 (#1742) 2022-04-02 14:19:34 +08:00
Kevin Wan
6dccfa29fd feat: let model customizable (#1738) 2022-04-01 22:19:52 +08:00
anqiansong
7e0b0ab0b1 Fix zrpc code generation error with --remote (#1739)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-04-01 22:19:33 +08:00
Kevin Wan
ac18cc470d chore: refactor to use const instead of var (#1731) 2022-04-01 15:23:45 +08:00
Fyn
f4471846ff feat(goctl): supports model code 'DO NOT EDIT' (#1728)
Resolves: #1710
2022-04-01 14:48:45 +08:00
anqiansong
9c2d526a11 Fix unit test (#1730)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-04-01 14:46:12 +08:00
Kevin Wan
2b9fc26c38 refactor: guard timeout on API files (#1726) 2022-03-31 21:39:02 +08:00
Xiaoju Jiang
321dc2d410 Added support for setting the parameter size accepted by the interface and custom timeout and maxbytes in API file (#1713)
* Added support for setting the parameter size accepted by the interface

* support custom timeout and maxbytes in API file

* support timeout used unit

* remove goctl maxBytes
2022-03-31 20:20:00 +08:00
Fyn
500bd87c85 fix(goctl): api format with reader input (#1722)
resolves #1721
2022-03-31 00:20:51 +08:00
Kevin Wan
e9620c8c05 chore: refactor code (#1708) 2022-03-24 22:10:15 +08:00
aimuz
70e51bb352 fix: empty slice are set to nil (#1702)
support for empty slce, Same behavior as json.Unmarshal
2022-03-24 21:41:38 +08:00
Kevin Wan
278cd123c8 feat: remove reentrance in redislock, timeout bug (#1704) 2022-03-24 16:17:01 +08:00
Kevin Wan
3febb1a5d0 chore: refactor code (#1700) 2022-03-23 19:09:45 +08:00
Mikael
d8054d8def fix -cache=true insert no clean cache (#1672)
* fix -cache=true insert no clean cache

* fix -cache=true insert no clean cache
2022-03-23 18:55:16 +08:00
Kevin Wan
ec271db7a0 chore: refactor code (#1699) 2022-03-23 18:24:44 +08:00
benqi
bbac994c8a feat: add getset command in redis and kv (#1693) 2022-03-23 18:02:56 +08:00
Kevin Wan
c1d9e6a00b feat: add httpc.Parse (#1698) 2022-03-23 17:58:21 +08:00
anqiansong
0aeb49a6b0 Add verbose flag (#1696)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-03-22 21:00:26 +08:00
Kevin Wan
fe262766b4 chore: fix lint issue (#1694) 2022-03-22 13:31:05 +08:00
Kevin Wan
7181505c8a Update LICENSE 2022-03-21 10:32:41 +08:00
Kevin Wan
f060a226bc refactor: simplify the code (#1670) 2022-03-20 17:26:12 +08:00
Mervin.Wong
93d524b797 fix: the new RawFieldNames considers the tag with options. (#1663)
Co-authored-by: JinfaWang <wangjinfa@iie.ac.cn>
2022-03-20 16:59:19 +08:00
anqiansong
5c169f4f49 Remove debug log (#1669)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-03-20 16:28:36 +08:00
Kevin Wan
d29dfa12e3 feat: support -base to specify base image for goctl docker (#1668)
* feat: support -base to specify base image for goctl docker

* chore: update usage
2022-03-20 11:17:55 +08:00
anqiansong
194f55e08e Remove unused code (#1667)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-03-19 23:15:11 +08:00
Kevin Wan
c0f9892fe3 feat: add Dockerfile for goctl (#1666) 2022-03-19 23:07:17 +08:00
anqiansong
227104d7d7 feat: Remove command goctl rpc proto (#1665)
* Fix goctl completion expression

* Fix code generation error if the pkg of pb/grpc is same to zrpc call client pkg

* Remove deprecated comment on action goctl rpc new

* Remove zrpc code generation on action goctl rpc proto

* Remove zrpc code generation on action goctl rpc proto

* Remove Generator interface

Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-03-19 22:50:22 +08:00
anqiansong
448029aa4b Mkdir if not exists (#1659)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-03-17 21:44:22 +08:00
Fyn
17e0afeac0 fix(goctl): model method FindOneCtx should be FindOne (#1656) 2022-03-17 17:16:53 +08:00
ronething-bot
18916b5189 [fix] typo (#1655) 2022-03-17 10:00:29 +08:00
Kevin Wan
c11a09be23 chore: remove unnecessary env (#1654) 2022-03-16 17:31:10 +08:00
ronething-bot
56e1ecf2f3 fix: typo (#1646) 2022-03-15 17:46:13 +08:00
Kevin Wan
f9e6013a6c refactor: httpc package for easy to use (#1645) 2022-03-15 14:18:46 +08:00
Kevin Wan
b5d1d8b0d1 refactor: httpc package for easy to use (#1643) 2022-03-14 20:15:14 +08:00
xybingbing
09e6d94f9e FindOneBy 漏 Context (#1642) 2022-03-14 18:56:26 +08:00
Kevin Wan
2a5717d7fb feat: add httpc/Service for convinience (#1641) 2022-03-14 15:36:06 +08:00
Kevin Wan
85cf662c6f feat: add httpc/Get httpc/Post (#1640) 2022-03-13 14:49:14 +08:00
Kevin Wan
3279a7ef0f feat: add rest/httpc to make http requests governacible (#1638)
* feat: change x-trace-id to traceparent to follow opentelemetry

* feat: add rest/httpc to make http requests governacible

* chore: remove blank lines
2022-03-13 14:11:14 +08:00
Kevin Wan
fec908a19b Update ROADMAP.md
update roadmap.
2022-03-13 14:09:11 +08:00
Kevin Wan
f5ed0cda58 Update ROADMAP.md
update roadmap.
2022-03-13 14:08:28 +08:00
anqiansong
cc9d16f505 fix: Update unix-like path regex (#1637)
* Revert import value regex

* Update linux path regex

Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-03-12 22:21:17 +08:00
Kevin Wan
c05d74b44c feat: support cpu stat on cgroups v2 (#1636)
* feat: cpu stat

* feat: add cpu stat for cgroup2

* feat: add cpu stat for cgroup2, tidy mod

* feat: support cpu stat in cgroup v2
2022-03-12 21:00:04 +08:00
mlr3000
32c88b6352 feat: support oracle :N dynamic parameters (#1552)
* chore:use struct pointer

* feat: support oracle :N dynamic parameters

* Update utils.go

* Update utils.go

* Update utils.go

pg argIndex will not always go up

* Update utils_test.go

* Keep the original

* Update utils_test.go
2022-03-12 18:49:07 +08:00
Kevin Wan
7dabec260f Update readme-cn.md
update readme.
2022-03-12 16:09:18 +08:00
Kevin Wan
4feb88f9b5 Update readme-cn.md
update readme.
2022-03-12 15:23:52 +08:00
Kevin Wan
2776caed0e Update readme.md
update readme.
2022-03-12 15:19:51 +08:00
chensy
c55694d957 Support for referencing types in different API files using format (#1630) 2022-03-12 15:17:31 +08:00
Ziyi Zhang
209ffb934b fix(goctl): kotlin code generation (#1632)
Signed-off-by: Ziyi Zhang <soasurs@gmail.com>
2022-03-11 13:44:18 +08:00
Kevin Wan
26a33932cd feat: support scratch as the base docker image (#1634) 2022-03-11 12:15:38 +08:00
Kevin Wan
d6a692971f chore: reduce the docker image size (#1633)
* chore: reduce the docker image size

* chore: format dockerfile
2022-03-11 11:30:21 +08:00
anqiansong
4624390e54 Fix #1585 #1547 (#1624) 2022-03-09 19:26:35 +08:00
Kevin Wan
63b7d292c1 chore: update goctl version to 1.3.3, change docker build temp dir (#1621) 2022-03-07 14:44:12 +08:00
Fyn
365c569d7c fix(goctl): dart gen user defined struct array (#1620) 2022-03-07 14:11:47 +08:00
anqiansong
68a81fea8a Fix #1609 (#1617) 2022-03-05 22:52:32 +08:00
anqiansong
08a8bd7ef7 Fix #1614 (#1616) 2022-03-05 21:40:41 +08:00
Kevin Wan
b939ce75ba chore: refactor code (#1613) 2022-03-04 17:55:13 +08:00
Kevin Wan
3b7ca86e4f chore: add unit tests (#1615)
* test: add more tests

* test: add more tests
2022-03-04 17:54:09 +08:00
Javen
60760b52ab model中db标签增加'-'符号以支持数据库查询时忽略对应字段. (#1612) 2022-03-04 17:00:46 +08:00
qi
96c128c58a fix: HitQuota should be returned instead of Allowed when limit is equal to 1. (#1581) 2022-03-04 16:14:45 +08:00
Fyn
0c35f39a7d fix: fix(gctl): apiparser_parser auto format (#1607) 2022-03-04 15:36:20 +08:00
Fyn
6a66dde0a1 feat(goctl): api dart support flutter v2 (#1603)
0. support null-safety code gen
1. supports -legacy flag for legacy code gen
2. supports -hostname flag for server hostname
3. use dart official format
4. fix some some bugs

Resolves: #1602
2022-03-04 15:34:13 +08:00
Kevin Wan
36b9fcba44 Update readme-cn.md 2022-03-03 14:35:48 +08:00
Kevin Wan
bf99dda620 Update readme-cn.md 2022-03-03 14:35:10 +08:00
Kevin Wan
511dfcb409 Update readme.md 2022-03-03 14:34:34 +08:00
Kevin Wan
900bc96420 test: add more tests (#1604) 2022-03-02 21:19:04 +08:00
Kevin Wan
be277a7376 Update readme-cn.md
add go-zero users.
2022-03-02 21:18:31 +08:00
Kevin Wan
f15a4f9188 chore: update go-zero to v1.3.1 in goctl (#1599) 2022-03-01 20:56:57 +08:00
Kevin Wan
e31128650e Revert "🐞 fix(gen): pg gen of insert (#1591)" (#1598)
This reverts commit cc4c4928e0.
2022-03-01 20:27:59 +08:00
Kevin Wan
168740b64d chore: upgrade etcd (#1597) 2022-03-01 20:16:44 +08:00
toutou_o
cc4c4928e0 🐞 fix(gen): pg gen of insert (#1591)
Co-authored-by: kurimi1 <d0n41df@gmail.com>
2022-03-01 19:53:23 +08:00
Fyn
fba6543b23 fix: goctl api dart support form tag (#1596) 2022-03-01 16:17:37 +08:00
Kevin Wan
877eb6ac56 Update readme.md
add producthunt.
2022-03-01 16:11:17 +08:00
Kevin Wan
259a5a13e7 chore: fix data race (#1593) 2022-02-28 23:17:51 +08:00
Fyn
cf7c7cb392 build: update goctl dependency ddl-parser to v1.0.3 (#1586)
* build: update goctl dependency ddl-parser to v1.0.3

* fix: race condition when testing logx

Resolves: #1587
2022-02-28 17:31:59 +08:00
ccx
86d01e2e99 test: add testcase for FIFO Queue in collection module (#1589)
cover the case of non-zero value for q.Header when q.Elements expands
2022-02-28 17:15:11 +08:00
Kevin Wan
7a28e19a27 Update readme-cn.md 2022-02-27 23:30:14 +08:00
Kevin Wan
900ea63d68 Update readme-cn.md
add migration notice.
2022-02-27 23:29:45 +08:00
Kevin Wan
87ab86cdd0 Update readme.md 2022-02-27 23:26:03 +08:00
Kevin Wan
0697494ffd Update readme.md
Add migrate steps.
2022-02-27 23:25:22 +08:00
anqiansong
ffd69a2f5e Fix bug int overflow while build goctl on arch 386 (#1582)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-02-27 10:51:57 +08:00
Kevin Wan
66f10bb5e6 chore: add goctl command help (#1578) 2022-02-26 17:02:04 +08:00
Kevin Wan
8131a0e777 Update readme.md
update chat link.
2022-02-25 23:02:09 +08:00
Kevin Wan
32a557dff6 Update readme.md
add discord.
2022-02-25 23:01:15 +08:00
Fyn
db949e40f1 feat: supports importValue for more path formats (#1569)
`importValueRegex` now can match more path formats

Resolves: #1568
2022-02-25 11:16:57 +08:00
Kevin Wan
e0454138e0 update goctl to go 1.16 for io/fs usage (#1571)
* update goctl to go 1.16 for io/fs usage

* feat: support pg serial type for auto_increment (#1563)

* add correct example for pg's url

* 🐞 fix: merge

* 🐞 fix: pg default port

*  feat: support serial type

Co-authored-by: kurimi1 <d0n41df@gmail.com>

* chore: format code

Co-authored-by: toutou_o <33993460+kurimi1@users.noreply.github.com>
Co-authored-by: kurimi1 <d0n41df@gmail.com>
2022-02-24 13:58:53 +08:00
toutou_o
3b07ed1b97 feat: support pg serial type for auto_increment (#1563)
* add correct example for pg's url

* 🐞 fix: merge

* 🐞 fix: pg default port

*  feat: support serial type

Co-authored-by: kurimi1 <d0n41df@gmail.com>
2022-02-24 13:39:31 +08:00
anqiansong
daa98f5a27 Feature: Add goctl env (#1557) 2022-02-21 10:19:33 +08:00
Kevin Wan
842656aa90 feat: log 404 requests with traceid (#1554) 2022-02-19 20:50:33 +08:00
Kevin Wan
aa29036cb3 feat: support ctx in sql model generation (#1551) 2022-02-17 10:28:55 +08:00
Kevin Wan
607bae27fa feat: support ctx in sqlx/sqlc, listed in ROADMAP (#1535)
* feat: support ctx in sqlx/sqlc

* chore: update roadmap

* fix: context.Canceled should be acceptable

* use %w to wrap errors

* chore: remove unused vars
2022-02-16 19:31:43 +08:00
Kevin Wan
7c63676be4 docs: add go-zero users (#1546) 2022-02-16 12:02:52 +08:00
Kevin Wan
9e113909b3 ignore context.Canceled for redis breaker (#1545) 2022-02-15 21:31:30 +08:00
Kevin Wan
bd105474ca chore: update help message (#1544) 2022-02-15 21:19:40 +08:00
Mikael
a078f5d764 add the serviceAccount of deployment (#1543)
Co-authored-by: 977231903@qq.com <>
2022-02-15 20:57:14 +08:00
Kevin Wan
b215fa3ee6 fix #1541 (#1542) 2022-02-15 18:40:26 +08:00
mlr3000
50b1928502 chore:use struct pointer (#1538) 2022-02-15 11:34:48 +08:00
Kevin Wan
493e3bcf4b docs: update roadmap (#1537) 2022-02-15 08:37:03 +08:00
Kevin Wan
6deb80625d fix issue of default migrate version (#1536)
* fix issue of default migrate version

* chore: update console colors
2022-02-14 23:09:32 +08:00
Kevin Wan
6ab051568c Update readme-cn.md
add go-zero users
2022-02-14 16:57:48 +08:00
Kevin Wan
2732d3cdae chore: refactor cache (#1532) 2022-02-13 18:04:31 +08:00
chenquan
e8c307e4dc feat: support ctx in Cache (#1518)
* feature: support ctx in `Cache`

Signed-off-by: chenquan <chenquan.dev@foxmail.com>

* fix: `errors.Is` instead of `=`

Signed-off-by: chenquan <chenquan.dev@foxmail.com>
2022-02-13 17:28:14 +08:00
Kevin Wan
84ddc660c4 chore: goctl format issue (#1531) 2022-02-13 13:17:19 +08:00
Kevin Wan
e60e707955 upgrade grpc version (#1530) 2022-02-12 23:58:41 +08:00
Kevin Wan
cf4321b2d0 fix #1525 (#1527) 2022-02-11 23:04:57 +08:00
chenquan
1993faf2f8 fix: fix a typo (#1522)
Signed-off-by: chenquan <chenquan.dev@foxmail.com>
2022-02-11 21:15:45 +08:00
Kevin Wan
0ce85376bf chore: update goctl version to 1.3.2 (#1524) 2022-02-11 21:02:50 +08:00
Kevin Wan
a40254156f refactor: refactor yaml unmarshaler (#1517) 2022-02-09 17:22:52 +08:00
chenquan
05cc62f5ff chore: optimize yaml unmarshaler (#1513) 2022-02-09 16:57:00 +08:00
chenquan
9c2c90e533 chore: make error clearer (#1514) 2022-02-09 14:40:05 +08:00
Kevin Wan
822ee2e1c5 feat: update go-redis to v8, support ctx in redis methods (#1507)
* feat: update go-redis to v8, support ctx in redis methods

* fix compile errors

* chore: remove unused const

* chore: add tracing log on redis
2022-02-09 11:06:06 +08:00
anqiansong
77482c8946 fixes typo (#1511)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-02-08 22:16:38 +08:00
Kevin Wan
7ef0ab3119 Update readme-cn.md 2022-02-08 11:47:33 +08:00
anqiansong
8bd89a297a feature: Add goctl completion (#1505)
* feature: Add `goctl completion`

* Update const

Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-02-08 10:50:21 +08:00
Kevin Wan
bb75cc796e test: change fuzz tests (#1504) 2022-02-05 09:44:01 +08:00
Kevin Wan
0fdd8f54eb ci: add test for win (#1503)
* ci: add test for win

* ci: update check names

* ci: use go build instead of go test to verify win test

* fix: windows test failure

* chore: disable logs in tests
2022-02-05 00:06:23 +08:00
anqiansong
b1ffc464cd fix typo: goctl protoc usage (#1502)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-02-03 22:13:02 +08:00
Kevin Wan
50174960e4 chore: update command comment (#1501) 2022-02-02 22:02:08 +08:00
Kevin Wan
8f46eab977 fix: goctl not compile on windows (#1500) 2022-02-01 13:58:08 +08:00
Kevin Wan
ec299085f5 docs: update tal-tech to zeromico in docs (#1498) 2022-02-01 13:03:30 +08:00
Kevin Wan
7727d70634 chore: update goctl version (#1497) 2022-02-01 09:50:26 +08:00
Kevin Wan
5f9d101bc6 feat: add runtime stats monitor (#1496) 2022-02-01 01:34:25 +08:00
Kevin Wan
6c2abe7474 fix: goroutine stuck on edge case (#1495)
* fix: goroutine stuck on edge case

* refactor: simplify mapreduce implementation
2022-01-30 13:09:21 +08:00
Kevin Wan
14a902c1a7 feat: handling panic in mapreduce, panic in calling goroutine, not inside goroutines (#1490)
* feat: handle panic

* chore: update fuzz test

* chore: optimize square sum algorithm
2022-01-28 10:59:41 +08:00
Kevin Wan
5ad6a6d229 Update readme-cn.md
add slogan
2022-01-27 17:16:30 +08:00
Kevin Wan
6f4b97864a chore: improve migrate confirmation (#1488) 2022-01-27 11:30:35 +08:00
Kevin Wan
0e0abc3a95 chore: update warning message (#1487) 2022-01-26 23:47:57 +08:00
anqiansong
696fda1db4 patch: goctl migrate (#1485)
* * Add signal check
* Add deprecated pkg check

* fix typo `replacementBuilderx`

* output to console if verbose

Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-01-26 23:24:25 +08:00
Kevin Wan
c1d2634427 chore: update go version for goctl (#1484) 2022-01-26 14:27:43 +08:00
Kevin Wan
4b7a680ac5 refactor: rename from tal-tech to zeromicro for goctl (#1481) 2022-01-25 23:15:07 +08:00
Kevin Wan
b3e7d2901f Feature/trie ac automation (#1479)
* fix: trie ac automation issues

* fix: trie ac automation issues

* fix: trie ac automation issues

* fix: trie ac automation issues
2022-01-25 11:14:56 +08:00
anqiansong
cdf7ec213c fix #1468 (#1478)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-01-24 22:23:20 +08:00
Kevin Wan
f1102fb262 chore: optimize string search with Aho–Corasick algorithm (#1476)
* chore: optimize string search with Aho–Corasick algorithm

* chore: optimize keywords replacer

* fix: replacer bugs

* chore: reorder members
2022-01-23 23:37:02 +08:00
Keqi Huang
09d1fad6e0 Polish the words in readme.md (#1475) 2022-01-22 12:20:11 +08:00
Kevin Wan
379c65a3ef docs: add go-zero users (#1473) 2022-01-20 22:36:17 +08:00
Kevin Wan
fdc7f64d6f chore: update unauthorized callback calling order (#1469)
* chore: update unauthorized callback calling order

* chore: add comments
2022-01-20 21:09:45 +08:00
anqiansong
df0f8ed59e Fix/issue#1289 (#1460)
* fix #1289

* Add unit test case

* fix `jwtTransKey`

* fix `jwtTransKey`

Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-01-18 11:52:30 +08:00
anqiansong
c903966fc7 patch: save missing templates to disk (#1463)
Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-01-18 10:45:05 +08:00
anqiansong
e57fa8ff53 Fix/issue#1447 (#1458)
* Add data for template to render

* fix #1447

Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-01-18 10:36:38 +08:00
Kevin Wan
bf2feee5b7 feat: implement console plain output for debug logs (#1456)
* feat: implement console plain output for debug logs

* chore: rename console encoding to plain

* chore: refactor names
2022-01-17 12:43:15 +08:00
Letian Jiang
ce05c429fc chore: check interface satisfaction w/o allocating new variable (#1454) 2022-01-16 23:34:42 +08:00
Kevin Wan
272a3f347d chore: remove jwt deprecated (#1452) 2022-01-16 10:34:44 +08:00
shenbaise9527
13db7a1931 feat: 支持redis的LTrim方法 (#1443) 2022-01-16 10:27:34 +08:00
Kevin Wan
468c237189 chore: upgrade dependencies (#1444)
* chore: upgrade dependencies

* ci: upgrade go to 1.15
2022-01-14 11:01:02 +08:00
Kevin Wan
b9b80c068b ci: add translator action (#1441) 2022-01-12 17:57:39 +08:00
anqiansong
9b592b3dee Feature rpc protoc (#1251)
* code generation by protoc

* generate pb by protoc direct

* support: grpc code generation by protoc directly

* format code

* check --go_out & --go-grpc_out

* fix typo

* Update version

* fix typo

* optimize: remove deprecated unit test

* format code

Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-01-11 20:34:25 +08:00
Kevin Wan
2203809e5e chore: fix typo (#1437) 2022-01-11 20:23:59 +08:00
Kevin Wan
8d6d37f71e remove unnecessary drain, fix data race (#1435)
* remove unnecessary drain, fix data race

* chore: fix parameter order

* refactor: rename MapVoid to ForEach in mr
2022-01-11 16:17:51 +08:00
Kevin Wan
ea4f2af67f fix: mr goroutine leak on context deadline (#1433)
* fix: mr goroutine leak on context deadline

* test: update fx test check
2022-01-10 22:06:10 +08:00
Kevin Wan
53af194ef9 chore: refactor periodlimit (#1428)
* chore: refactor periodlimit

* chore: add comments
2022-01-09 16:22:34 +08:00
Kevin Wan
5e0e2d2b14 docs: add go-zero users (#1425) 2022-01-08 21:41:27 +08:00
Kevin Wan
74c99184c5 docs: add go-zero users (#1424) 2022-01-08 17:08:44 +08:00
Kevin Wan
eb4b86137a fix: golint issue (#1423) 2022-01-08 16:06:56 +08:00
Kevin Wan
9c4f4f3b4e update docs (#1421) 2022-01-07 12:08:45 +08:00
spectatorMrZ
240132e7c7 Fix pg model generation without tag (#1407)
1. fix pg model struct haven't tag
2. add pg command test from datasource
2022-01-07 10:45:26 +08:00
anqiansong
9d67fc4cfb feat: Add migrate (#1419)
* Add migrate

* Remove unused module

* refactor filename

* rename refactor to migrate

Co-authored-by: anqiansong <anqiansong@bytedance.com>
2022-01-06 18:48:34 +08:00
Kevin Wan
892f93a716 docs: update install readme (#1417) 2022-01-05 12:31:49 +08:00
441 changed files with 16067 additions and 4035 deletions

View File

@@ -7,32 +7,50 @@ on:
branches: [ master ] branches: [ master ]
jobs: jobs:
build: test-linux:
name: Build name: Linux
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ^1.15
id: go
- name: Set up Go 1.x - name: Check out code into the Go module directory
uses: actions/setup-go@v2 uses: actions/checkout@v2
with:
go-version: ^1.14
id: go
- name: Check out code into the Go module directory - name: Get dependencies
uses: actions/checkout@v2 run: |
go get -v -t -d ./...
- name: Get dependencies - name: Lint
run: | run: |
go get -v -t -d ./... go vet -stdmethods=false $(go list ./...)
go install mvdan.cc/gofumpt@latest
test -z "$(gofumpt -l -extra .)" || echo "Please run 'gofumpt -l -w -extra .'"
- name: Lint - name: Test
run: | run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
go vet -stdmethods=false $(go list ./...)
go install mvdan.cc/gofumpt@latest
test -z "$(gofumpt -s -l -extra .)" || echo "Please run 'gofumpt -l -w -extra .'"
- name: Test - name: Codecov
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./... uses: codecov/codecov-action@v2
- name: Codecov test-win:
uses: codecov/codecov-action@v2 name: Windows
runs-on: windows-latest
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ^1.15
- name: Checkout codebase
uses: actions/checkout@v2
- name: Test
run: |
go mod verify
go mod download
go test -v -race ./...
cd tools/goctl && go build -v goctl.go

18
.github/workflows/issue-translator.yml vendored Normal file
View File

@@ -0,0 +1,18 @@
name: 'issue-translator'
on:
issue_comment:
types: [created]
issues:
types: [opened]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: tomsun28/issues-translate-action@v2.6
with:
IS_MODIFY_TITLE: true
# not require, default false, . Decide whether to modify the issue title
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑‍🤝‍🧑👫🧑🏿‍🤝‍🧑🏻👩🏾‍🤝‍👨🏿👬🏿
# not require. Customize the translation robot prefix message.

3
.gitignore vendored
View File

@@ -16,7 +16,8 @@
**/logs **/logs
# for test purpose # for test purpose
adhoc **/adhoc
go.work
# gitlab ci # gitlab ci
.cache .cache

View File

@@ -40,7 +40,7 @@ We will help you to contribute in different areas like filing issues, developing
getting your work reviewed and merged. getting your work reviewed and merged.
If you have questions about the development process, If you have questions about the development process,
feel free to [file an issue](https://github.com/tal-tech/go-zero/issues/new/choose). feel free to [file an issue](https://github.com/zeromicro/go-zero/issues/new/choose).
## Find something to work on ## Find something to work on
@@ -50,10 +50,10 @@ Here is how you get started.
### Find a good first topic ### Find a good first topic
[go-zero](https://github.com/tal-tech/go-zero) has beginner-friendly issues that provide a good first issue. [go-zero](https://github.com/zeromicro/go-zero) has beginner-friendly issues that provide a good first issue.
For example, [go-zero](https://github.com/tal-tech/go-zero) has For example, [go-zero](https://github.com/zeromicro/go-zero) has
[help wanted](https://github.com/tal-tech/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) and [help wanted](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) and
[good first issue](https://github.com/tal-tech/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) [good first issue](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
labels for issues that should not need deep knowledge of the system. labels for issues that should not need deep knowledge of the system.
We can help new contributors who wish to work on such issues. We can help new contributors who wish to work on such issues.
@@ -79,7 +79,7 @@ This is a rough outline of what a contributor's workflow looks like:
- Create a topic branch from where to base the contribution. This is usually master. - Create a topic branch from where to base the contribution. This is usually master.
- Make commits of logical units. - Make commits of logical units.
- Push changes in a topic branch to a personal fork of the repository. - Push changes in a topic branch to a personal fork of the repository.
- Submit a pull request to [go-zero](https://github.com/tal-tech/go-zero). - Submit a pull request to [go-zero](https://github.com/zeromicro/go-zero).
## Creating Pull Requests ## Creating Pull Requests

View File

@@ -1,6 +1,6 @@
MIT License MIT License
Copyright (c) 2020 xiaoheiban_server_go Copyright (c) 2022 zeromicro
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal

View File

@@ -20,9 +20,9 @@ We hope that the items listed below will inspire further engagement from the com
- [x] Support `goctl bug` to report bugs conveniently - [x] Support `goctl bug` to report bugs conveniently
## 2022 ## 2022
- [ ] Support `goctl mock` command to start a mocking server with given `.api` file - [x] Support `context` in redis related methods for timeout and tracing
- [ ] Add `httpx.Client` with governance, like circuit breaker etc. - [x] Support `context` in sql related methods for timeout and tracing
- [ ] Support `goctl doctor` command to report potential issues for given service
- [ ] Support `context` in redis related methods for timeout and tracing
- [ ] Support `context` in sql related methods for timeout and tracing
- [ ] Support `context` in mongodb related methods for timeout and tracing - [ ] Support `context` in mongodb related methods for timeout and tracing
- [x] Add `httpc.Do` with HTTP call governance, like circuit breaker etc.
- [ ] Support `goctl doctor` command to report potential issues for given service
- [ ] Support `goctl mock` command to start a mocking server with given `.api` file

View File

@@ -171,7 +171,7 @@ func (lt loggedThrottle) allow() (Promise, error) {
func (lt loggedThrottle) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error { func (lt loggedThrottle) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
return lt.logError(lt.internalThrottle.doReq(req, fallback, func(err error) bool { return lt.logError(lt.internalThrottle.doReq(req, fallback, func(err error) bool {
accept := acceptable(err) accept := acceptable(err)
if !accept { if !accept && err != nil {
lt.errWin.add(err.Error()) lt.errWin.add(err.Error())
} }
return accept return accept

View File

@@ -98,13 +98,18 @@ func (c *Cache) Get(key string) (interface{}, bool) {
// Set sets value into c with key. // Set sets value into c with key.
func (c *Cache) Set(key string, value interface{}) { func (c *Cache) Set(key string, value interface{}) {
c.SetWithExpire(key, value, c.expire)
}
// SetWithExpire sets value into c with key and expire with the given value.
func (c *Cache) SetWithExpire(key string, value interface{}, expire time.Duration) {
c.lock.Lock() c.lock.Lock()
_, ok := c.data[key] _, ok := c.data[key]
c.data[key] = value c.data[key] = value
c.lruCache.add(key) c.lruCache.add(key)
c.lock.Unlock() c.lock.Unlock()
expiry := c.unstableExpiry.AroundDuration(c.expire) expiry := c.unstableExpiry.AroundDuration(expire)
if ok { if ok {
c.timingWheel.MoveTimer(key, expiry) c.timingWheel.MoveTimer(key, expiry)
} else { } else {

View File

@@ -18,7 +18,7 @@ func TestCacheSet(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
cache.Set("first", "first element") cache.Set("first", "first element")
cache.Set("second", "second element") cache.SetWithExpire("second", "second element", time.Second*3)
value, ok := cache.Get("first") value, ok := cache.Get("first")
assert.True(t, ok) assert.True(t, ok)

View File

@@ -61,3 +61,41 @@ func TestPutMore(t *testing.T) {
assert.Equal(t, string(element), string(body.([]byte))) assert.Equal(t, string(element), string(body.([]byte)))
} }
} }
func TestPutMoreWithHeaderNotZero(t *testing.T) {
elements := [][]byte{
[]byte("hello"),
[]byte("world"),
[]byte("again"),
}
queue := NewQueue(4)
for i := range elements {
queue.Put(elements[i])
}
// take 1
body, ok := queue.Take()
assert.True(t, ok)
element, ok := body.([]byte)
assert.True(t, ok)
assert.Equal(t, element, []byte("hello"))
// put more
queue.Put([]byte("b4"))
queue.Put([]byte("b5")) // will store in elements[0]
queue.Put([]byte("b6")) // cause expansion
results := [][]byte{
[]byte("world"),
[]byte("again"),
[]byte("b4"),
[]byte("b5"),
[]byte("b6"),
}
for _, element := range results {
body, ok := queue.Take()
assert.True(t, ok)
assert.Equal(t, string(element), string(body.([]byte)))
}
}

View File

@@ -2,6 +2,7 @@ package collection
import ( import (
"container/list" "container/list"
"errors"
"fmt" "fmt"
"time" "time"
@@ -12,6 +13,11 @@ import (
const drainWorkers = 8 const drainWorkers = 8
var (
ErrClosed = errors.New("TimingWheel is closed already")
ErrArgument = errors.New("incorrect task argument")
)
type ( type (
// Execute defines the method to execute the task. // Execute defines the method to execute the task.
Execute func(key, value interface{}) Execute func(key, value interface{})
@@ -59,14 +65,15 @@ type (
// NewTimingWheel returns a TimingWheel. // NewTimingWheel returns a TimingWheel.
func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) { func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) {
if interval <= 0 || numSlots <= 0 || execute == nil { if interval <= 0 || numSlots <= 0 || execute == nil {
return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p", interval, numSlots, execute) return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p",
interval, numSlots, execute)
} }
return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval)) return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval))
} }
func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute, ticker timex.Ticker) ( func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute,
*TimingWheel, error) { ticker timex.Ticker) (*TimingWheel, error) {
tw := &TimingWheel{ tw := &TimingWheel{
interval: interval, interval: interval,
ticker: ticker, ticker: ticker,
@@ -89,47 +96,67 @@ func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execu
} }
// Drain drains all items and executes them. // Drain drains all items and executes them.
func (tw *TimingWheel) Drain(fn func(key, value interface{})) { func (tw *TimingWheel) Drain(fn func(key, value interface{})) error {
tw.drainChannel <- fn select {
case tw.drainChannel <- fn:
return nil
case <-tw.stopChannel:
return ErrClosed
}
} }
// MoveTimer moves the task with the given key to the given delay. // MoveTimer moves the task with the given key to the given delay.
func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) { func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) error {
if delay <= 0 || key == nil { if delay <= 0 || key == nil {
return return ErrArgument
} }
tw.moveChannel <- baseEntry{ select {
case tw.moveChannel <- baseEntry{
delay: delay, delay: delay,
key: key, key: key,
}:
return nil
case <-tw.stopChannel:
return ErrClosed
} }
} }
// RemoveTimer removes the task with the given key. // RemoveTimer removes the task with the given key.
func (tw *TimingWheel) RemoveTimer(key interface{}) { func (tw *TimingWheel) RemoveTimer(key interface{}) error {
if key == nil { if key == nil {
return return ErrArgument
} }
tw.removeChannel <- key select {
case tw.removeChannel <- key:
return nil
case <-tw.stopChannel:
return ErrClosed
}
} }
// SetTimer sets the task value with the given key to the delay. // SetTimer sets the task value with the given key to the delay.
func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) { func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) error {
if delay <= 0 || key == nil { if delay <= 0 || key == nil {
return return ErrArgument
} }
tw.setChannel <- timingEntry{ select {
case tw.setChannel <- timingEntry{
baseEntry: baseEntry{ baseEntry: baseEntry{
delay: delay, delay: delay,
key: key, key: key,
}, },
value: value, value: value,
}:
return nil
case <-tw.stopChannel:
return ErrClosed
} }
} }
// Stop stops tw. // Stop stops tw. No more actions after stopping a TimingWheel.
func (tw *TimingWheel) Stop() { func (tw *TimingWheel) Stop() {
close(tw.stopChannel) close(tw.stopChannel)
} }

View File

@@ -28,7 +28,6 @@ func TestTimingWheel_Drain(t *testing.T) {
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) { tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
}, ticker) }, ticker)
defer tw.Stop()
tw.SetTimer("first", 3, testStep*4) tw.SetTimer("first", 3, testStep*4)
tw.SetTimer("second", 5, testStep*7) tw.SetTimer("second", 5, testStep*7)
tw.SetTimer("third", 7, testStep*7) tw.SetTimer("third", 7, testStep*7)
@@ -56,6 +55,8 @@ func TestTimingWheel_Drain(t *testing.T) {
}) })
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
assert.Equal(t, 0, count) assert.Equal(t, 0, count)
tw.Stop()
assert.Equal(t, ErrClosed, tw.Drain(func(key, value interface{}) {}))
} }
func TestTimingWheel_SetTimerSoon(t *testing.T) { func TestTimingWheel_SetTimerSoon(t *testing.T) {
@@ -102,6 +103,13 @@ func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
}) })
} }
func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
tw.Stop()
assert.Equal(t, ErrClosed, tw.SetTimer("any", 3, testStep))
}
func TestTimingWheel_MoveTimer(t *testing.T) { func TestTimingWheel_MoveTimer(t *testing.T) {
run := syncx.NewAtomicBool() run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
@@ -111,7 +119,6 @@ func TestTimingWheel_MoveTimer(t *testing.T) {
assert.Equal(t, 3, v.(int)) assert.Equal(t, 3, v.(int))
ticker.Done() ticker.Done()
}, ticker) }, ticker)
defer tw.Stop()
tw.SetTimer("any", 3, testStep*4) tw.SetTimer("any", 3, testStep*4)
tw.MoveTimer("any", testStep*7) tw.MoveTimer("any", testStep*7)
tw.MoveTimer("any", -testStep) tw.MoveTimer("any", -testStep)
@@ -125,6 +132,8 @@ func TestTimingWheel_MoveTimer(t *testing.T) {
} }
assert.Nil(t, ticker.Wait(waitTime)) assert.Nil(t, ticker.Wait(waitTime))
assert.True(t, run.True()) assert.True(t, run.True())
tw.Stop()
assert.Equal(t, ErrClosed, tw.MoveTimer("any", time.Millisecond))
} }
func TestTimingWheel_MoveTimerSoon(t *testing.T) { func TestTimingWheel_MoveTimerSoon(t *testing.T) {
@@ -175,6 +184,7 @@ func TestTimingWheel_RemoveTimer(t *testing.T) {
ticker.Tick() ticker.Tick()
} }
tw.Stop() tw.Stop()
assert.Equal(t, ErrClosed, tw.RemoveTimer("any"))
} }
func TestTimingWheel_SetTimer(t *testing.T) { func TestTimingWheel_SetTimer(t *testing.T) {

View File

@@ -2,6 +2,13 @@ package discov
import "errors" import "errors"
var (
// errEmptyEtcdHosts indicates that etcd hosts are empty.
errEmptyEtcdHosts = errors.New("empty etcd hosts")
// errEmptyEtcdKey indicates that etcd key is empty.
errEmptyEtcdKey = errors.New("empty etcd key")
)
// EtcdConf is the config item with the given key on etcd. // EtcdConf is the config item with the given key on etcd.
type EtcdConf struct { type EtcdConf struct {
Hosts []string Hosts []string
@@ -27,9 +34,9 @@ func (c EtcdConf) HasTLS() bool {
// Validate validates c. // Validate validates c.
func (c EtcdConf) Validate() error { func (c EtcdConf) Validate() error {
if len(c.Hosts) == 0 { if len(c.Hosts) == 0 {
return errors.New("empty etcd hosts") return errEmptyEtcdHosts
} else if len(c.Key) == 0 { } else if len(c.Key) == 0 {
return errors.New("empty etcd key") return errEmptyEtcdKey
} else { } else {
return nil return nil
} }

View File

@@ -11,10 +11,12 @@ type (
errorArray []error errorArray []error
) )
// Add adds err to be. // Add adds errs to be, nil errors are ignored.
func (be *BatchError) Add(err error) { func (be *BatchError) Add(errs ...error) {
if err != nil { for _, err := range errs {
be.errs = append(be.errs, err) if err != nil {
be.errs = append(be.errs, err)
}
} }
} }

View File

@@ -5,6 +5,9 @@ import (
"os" "os"
) )
// errExceedFileSize indicates that the file size is exceeded.
var errExceedFileSize = errors.New("exceed file size")
// A RangeReader is used to read a range of content from a file. // A RangeReader is used to read a range of content from a file.
type RangeReader struct { type RangeReader struct {
file *os.File file *os.File
@@ -29,7 +32,7 @@ func (rr *RangeReader) Read(p []byte) (n int, err error) {
} }
if rr.stop < rr.start || rr.start >= stat.Size() { if rr.stop < rr.start || rr.start >= stat.Size() {
return 0, errors.New("exceed file size") return 0, errExceedFileSize
} }
if rr.stop-rr.start < int64(len(p)) { if rr.stop-rr.start < int64(len(p)) {

View File

@@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/core/stringx"
"go.uber.org/goleak"
) )
func TestBuffer(t *testing.T) { func TestBuffer(t *testing.T) {
@@ -563,9 +564,6 @@ func equal(t *testing.T, stream Stream, data []interface{}) {
} }
func runCheckedTest(t *testing.T, fn func(t *testing.T)) { func runCheckedTest(t *testing.T, fn func(t *testing.T)) {
goroutines := runtime.NumGoroutine() defer goleak.VerifyNone(t)
fn(t) fn(t)
// let scheduler schedule first
time.Sleep(time.Millisecond)
assert.True(t, runtime.NumGoroutine() <= goroutines)
} }

View File

@@ -51,5 +51,5 @@ func unmarshalUseNumber(decoder *json.Decoder, v interface{}) error {
} }
func formatError(v string, err error) error { func formatError(v string, err error) error {
return fmt.Errorf("string: `%s`, error: `%s`", v, err.Error()) return fmt.Errorf("string: `%s`, error: `%w`", v, err)
} }

View File

@@ -8,23 +8,20 @@ import (
"github.com/zeromicro/go-zero/core/stores/redis" "github.com/zeromicro/go-zero/core/stores/redis"
) )
const ( // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key const periodScript = `local limit = tonumber(ARGV[1])
periodScript = `local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2]) local window = tonumber(ARGV[2])
local current = redis.call("INCRBY", KEYS[1], 1) local current = redis.call("INCRBY", KEYS[1], 1)
if current == 1 then if current == 1 then
redis.call("expire", KEYS[1], window) redis.call("expire", KEYS[1], window)
return 1 end
elseif current < limit then if current < limit then
return 1 return 1
elseif current == limit then elseif current == limit then
return 2 return 2
else else
return 0 return 0
end` end`
zoneDiff = 3600 * 8 // GMT+8 for our services
)
const ( const (
// Unknown means not initialized state. // Unknown means not initialized state.
@@ -104,7 +101,9 @@ func (h *PeriodLimit) Take(key string) (int, error) {
func (h *PeriodLimit) calcExpireSeconds() int { func (h *PeriodLimit) calcExpireSeconds() int {
if h.align { if h.align {
unix := time.Now().Unix() + zoneDiff now := time.Now()
_, offset := now.Zone()
unix := now.Unix() + int64(offset)
return h.period - int(unix%int64(h.period)) return h.period - int(unix%int64(h.period))
} }
@@ -112,6 +111,8 @@ func (h *PeriodLimit) calcExpireSeconds() int {
} }
// Align returns a func to customize a PeriodLimit with alignment. // Align returns a func to customize a PeriodLimit with alignment.
// For example, if we want to limit end users with 5 sms verification messages every day,
// we need to align with the local timezone and the start of the day.
func Align() PeriodOption { func Align() PeriodOption {
return func(l *PeriodLimit) { return func(l *PeriodLimit) {
l.align = true l.align = true

View File

@@ -23,10 +23,9 @@ func TestPeriodLimit_RedisUnavailable(t *testing.T) {
const ( const (
seconds = 1 seconds = 1
total = 100
quota = 5 quota = 5
) )
l := NewPeriodLimit(seconds, quota, redis.NewRedis(s.Addr(), redis.NodeType), "periodlimit") l := NewPeriodLimit(seconds, quota, redis.New(s.Addr()), "periodlimit")
s.Close() s.Close()
val, err := l.Take("first") val, err := l.Take("first")
assert.NotNil(t, err) assert.NotNil(t, err)
@@ -66,3 +65,13 @@ func testPeriodLimit(t *testing.T, opts ...PeriodOption) {
assert.Equal(t, 1, hitQuota) assert.Equal(t, 1, hitQuota)
assert.Equal(t, total-quota, overQuota) assert.Equal(t, total-quota, overQuota)
} }
func TestQuotaFull(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
l := NewPeriodLimit(1, 1, redis.New(s.Addr()), "periodlimit")
val, err := l.Take("first")
assert.Nil(t, err)
assert.Equal(t, HitQuota, val)
}

View File

@@ -85,8 +85,8 @@ func (lim *TokenLimiter) Allow() bool {
} }
// AllowN reports whether n events may happen at time now. // AllowN reports whether n events may happen at time now.
// Use this method if you intend to drop / skip events that exceed the rate rate. // Use this method if you intend to drop / skip events that exceed the rate.
// Otherwise use Reserve or Wait. // Otherwise, use Reserve or Wait.
func (lim *TokenLimiter) AllowN(now time.Time, n int) bool { func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
return lim.reserveN(now, n) return lim.reserveN(now, n)
} }
@@ -112,7 +112,8 @@ func (lim *TokenLimiter) reserveN(now time.Time, n int) bool {
// Lua boolean false -> r Nil bulk reply // Lua boolean false -> r Nil bulk reply
if err == redis.Nil { if err == redis.Nil {
return false return false
} else if err != nil { }
if err != nil {
logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err) logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
lim.startMonitor() lim.startMonitor()
return lim.rescueLimiter.AllowN(now, n) return lim.rescueLimiter.AllowN(now, n)

View File

@@ -3,10 +3,11 @@ package logx
// A LogConf is a logging config. // A LogConf is a logging config.
type LogConf struct { type LogConf struct {
ServiceName string `json:",optional"` ServiceName string `json:",optional"`
Mode string `json:",default=console,options=console|file|volume"` Mode string `json:",default=console,options=[console,file,volume]"`
Encoding string `json:",default=json,options=[json,plain]"`
TimeFormat string `json:",optional"` TimeFormat string `json:",optional"`
Path string `json:",default=logs"` Path string `json:",default=logs"`
Level string `json:",default=info,options=info|error|severe"` Level string `json:",default=info,options=[info,error,severe]"`
Compress bool `json:",optional"` Compress bool `json:",optional"`
KeepDays int `json:",optional"` KeepDays int `json:",optional"`
StackCooldownMillis int `json:",default=100"` StackCooldownMillis int `json:",default=100"`

View File

@@ -3,6 +3,7 @@ package logx
import ( import (
"fmt" "fmt"
"io" "io"
"sync/atomic"
"time" "time"
"github.com/zeromicro/go-zero/core/timex" "github.com/zeromicro/go-zero/core/timex"
@@ -79,10 +80,15 @@ func (l *durationLogger) WithDuration(duration time.Duration) Logger {
} }
func (l *durationLogger) write(writer io.Writer, level string, val interface{}) { func (l *durationLogger) write(writer io.Writer, level string, val interface{}) {
outputJson(writer, &durationLogger{ switch atomic.LoadUint32(&encoding) {
Timestamp: getTimestamp(), case plainEncodingType:
Level: level, writePlainAny(writer, level, val, l.Duration)
Content: val, default:
Duration: l.Duration, outputJson(writer, &durationLogger{
}) Timestamp: getTimestamp(),
Level: level,
Content: val,
Duration: l.Duration,
})
}
} }

View File

@@ -3,6 +3,7 @@ package logx
import ( import (
"log" "log"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
@@ -37,6 +38,19 @@ func TestWithDurationInfo(t *testing.T) {
assert.True(t, strings.Contains(builder.String(), "duration"), builder.String()) assert.True(t, strings.Contains(builder.String(), "duration"), builder.String())
} }
func TestWithDurationInfoConsole(t *testing.T) {
old := atomic.LoadUint32(&encoding)
atomic.StoreUint32(&encoding, plainEncodingType)
defer func() {
atomic.StoreUint32(&encoding, old)
}()
var builder strings.Builder
log.SetOutput(&builder)
WithDuration(time.Second).Info("foo")
assert.True(t, strings.Contains(builder.String(), "ms"), builder.String())
}
func TestWithDurationInfof(t *testing.T) { func TestWithDurationInfof(t *testing.T) {
var builder strings.Builder var builder strings.Builder
log.SetOutput(&builder) log.SetOutput(&builder)

View File

@@ -1,6 +1,7 @@
package logx package logx
import ( import (
"bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -31,6 +32,15 @@ const (
SevereLevel SevereLevel
) )
const (
jsonEncodingType = iota
plainEncodingType
jsonEncoding = "json"
plainEncoding = "plain"
plainEncodingSep = '\t'
)
const ( const (
accessFilename = "access.log" accessFilename = "access.log"
errorFilename = "error.log" errorFilename = "error.log"
@@ -62,9 +72,10 @@ var (
// ErrLogServiceNameNotSet is an error that indicates that the service name is not set. // ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
ErrLogServiceNameNotSet = errors.New("log service name must be set") ErrLogServiceNameNotSet = errors.New("log service name must be set")
timeFormat = "2006-01-02T15:04:05.000Z07" timeFormat = "2006-01-02T15:04:05.000Z07:00"
writeConsole bool writeConsole bool
logLevel uint32 logLevel uint32
encoding uint32 = jsonEncodingType
// use uint32 for atomic operations // use uint32 for atomic operations
disableStat uint32 disableStat uint32
infoLog io.WriteCloser infoLog io.WriteCloser
@@ -124,6 +135,12 @@ func SetUp(c LogConf) error {
if len(c.TimeFormat) > 0 { if len(c.TimeFormat) > 0 {
timeFormat = c.TimeFormat timeFormat = c.TimeFormat
} }
switch c.Encoding {
case plainEncoding:
atomic.StoreUint32(&encoding, plainEncodingType)
default:
atomic.StoreUint32(&encoding, jsonEncodingType)
}
switch c.Mode { switch c.Mode {
case consoleMode: case consoleMode:
@@ -258,7 +275,7 @@ func Infov(v interface{}) {
infoAnySync(v) infoAnySync(v)
} }
// Must checks if err is nil, otherwise logs the err and exits. // Must checks if err is nil, otherwise logs the error and exits.
func Must(err error) { func Must(err error) {
if err != nil { if err != nil {
msg := formatWithCaller(err.Error(), 3) msg := formatWithCaller(err.Error(), 3)
@@ -407,21 +424,31 @@ func infoTextSync(msg string) {
} }
func outputAny(writer io.Writer, level string, val interface{}) { func outputAny(writer io.Writer, level string, val interface{}) {
info := logEntry{ switch atomic.LoadUint32(&encoding) {
Timestamp: getTimestamp(), case plainEncodingType:
Level: level, writePlainAny(writer, level, val)
Content: val, default:
info := logEntry{
Timestamp: getTimestamp(),
Level: level,
Content: val,
}
outputJson(writer, info)
} }
outputJson(writer, info)
} }
func outputText(writer io.Writer, level, msg string) { func outputText(writer io.Writer, level, msg string) {
info := logEntry{ switch atomic.LoadUint32(&encoding) {
Timestamp: getTimestamp(), case plainEncodingType:
Level: level, writePlainText(writer, level, msg)
Content: msg, default:
info := logEntry{
Timestamp: getTimestamp(),
Level: level,
Content: msg,
}
outputJson(writer, info)
} }
outputJson(writer, info)
} }
func outputError(writer io.Writer, msg string, callDepth int) { func outputError(writer io.Writer, msg string, callDepth int) {
@@ -565,6 +592,62 @@ func statSync(msg string) {
} }
} }
func writePlainAny(writer io.Writer, level string, val interface{}, fields ...string) {
switch v := val.(type) {
case string:
writePlainText(writer, level, v, fields...)
case error:
writePlainText(writer, level, v.Error(), fields...)
case fmt.Stringer:
writePlainText(writer, level, v.String(), fields...)
default:
var buf bytes.Buffer
buf.WriteString(getTimestamp())
buf.WriteByte(plainEncodingSep)
buf.WriteString(level)
for _, item := range fields {
buf.WriteByte(plainEncodingSep)
buf.WriteString(item)
}
buf.WriteByte(plainEncodingSep)
if err := json.NewEncoder(&buf).Encode(val); err != nil {
log.Println(err.Error())
return
}
buf.WriteByte('\n')
if atomic.LoadUint32(&initialized) == 0 || writer == nil {
log.Println(buf.String())
return
}
if _, err := writer.Write(buf.Bytes()); err != nil {
log.Println(err.Error())
}
}
}
func writePlainText(writer io.Writer, level, msg string, fields ...string) {
var buf bytes.Buffer
buf.WriteString(getTimestamp())
buf.WriteByte(plainEncodingSep)
buf.WriteString(level)
for _, item := range fields {
buf.WriteByte(plainEncodingSep)
buf.WriteString(item)
}
buf.WriteByte(plainEncodingSep)
buf.WriteString(msg)
buf.WriteByte('\n')
if atomic.LoadUint32(&initialized) == 0 || writer == nil {
log.Println(buf.String())
return
}
if _, err := writer.Write(buf.Bytes()); err != nil {
log.Println(err.Error())
}
}
type logWriter struct { type logWriter struct {
logger *log.Logger logger *log.Logger
} }

View File

@@ -141,6 +141,78 @@ func TestStructedLogInfov(t *testing.T) {
}) })
} }
func TestStructedLogInfoConsoleAny(t *testing.T) {
doTestStructedLogConsole(t, func(writer io.WriteCloser) {
infoLog = writer
}, func(v ...interface{}) {
old := atomic.LoadUint32(&encoding)
atomic.StoreUint32(&encoding, plainEncodingType)
defer func() {
atomic.StoreUint32(&encoding, old)
}()
Infov(v)
})
}
func TestStructedLogInfoConsoleAnyString(t *testing.T) {
doTestStructedLogConsole(t, func(writer io.WriteCloser) {
infoLog = writer
}, func(v ...interface{}) {
old := atomic.LoadUint32(&encoding)
atomic.StoreUint32(&encoding, plainEncodingType)
defer func() {
atomic.StoreUint32(&encoding, old)
}()
Infov(fmt.Sprint(v...))
})
}
func TestStructedLogInfoConsoleAnyError(t *testing.T) {
doTestStructedLogConsole(t, func(writer io.WriteCloser) {
infoLog = writer
}, func(v ...interface{}) {
old := atomic.LoadUint32(&encoding)
atomic.StoreUint32(&encoding, plainEncodingType)
defer func() {
atomic.StoreUint32(&encoding, old)
}()
Infov(errors.New(fmt.Sprint(v...)))
})
}
func TestStructedLogInfoConsoleAnyStringer(t *testing.T) {
doTestStructedLogConsole(t, func(writer io.WriteCloser) {
infoLog = writer
}, func(v ...interface{}) {
old := atomic.LoadUint32(&encoding)
atomic.StoreUint32(&encoding, plainEncodingType)
defer func() {
atomic.StoreUint32(&encoding, old)
}()
Infov(ValStringer{
val: fmt.Sprint(v...),
})
})
}
func TestStructedLogInfoConsoleText(t *testing.T) {
doTestStructedLogConsole(t, func(writer io.WriteCloser) {
infoLog = writer
}, func(v ...interface{}) {
old := atomic.LoadUint32(&encoding)
atomic.StoreUint32(&encoding, plainEncodingType)
defer func() {
atomic.StoreUint32(&encoding, old)
}()
Info(fmt.Sprint(v...))
})
}
func TestStructedLogSlow(t *testing.T) { func TestStructedLogSlow(t *testing.T) {
doTestStructedLog(t, levelSlow, func(writer io.WriteCloser) { doTestStructedLog(t, levelSlow, func(writer io.WriteCloser) {
slowLog = writer slowLog = writer
@@ -432,6 +504,17 @@ func doTestStructedLog(t *testing.T, level string, setup func(writer io.WriteClo
assert.True(t, strings.Contains(val, message)) assert.True(t, strings.Contains(val, message))
} }
func doTestStructedLogConsole(t *testing.T, setup func(writer io.WriteCloser),
write func(...interface{})) {
const message = "hello there"
writer := new(mockWriter)
setup(writer)
atomic.StoreUint32(&initialized, 1)
write(message)
println(writer.String())
assert.True(t, strings.Contains(writer.String(), message))
}
func testSetLevelTwiceWithMode(t *testing.T, mode string) { func testSetLevelTwiceWithMode(t *testing.T, mode string) {
SetUp(LogConf{ SetUp(LogConf{
Mode: mode, Mode: mode,
@@ -456,3 +539,11 @@ func testSetLevelTwiceWithMode(t *testing.T, mode string) {
ErrorStackf(message) ErrorStackf(message)
assert.Equal(t, 0, writer.builder.Len()) assert.Equal(t, 0, writer.builder.Len())
} }
type ValStringer struct {
val string
}
func (v ValStringer) String() string {
return v.val
}

View File

@@ -29,9 +29,9 @@ func TestRedirector(t *testing.T) {
} }
func captureOutput(f func()) string { func captureOutput(f func()) string {
atomic.StoreUint32(&initialized, 1)
writer := new(mockWriter) writer := new(mockWriter)
infoLog = writer infoLog = writer
atomic.StoreUint32(&initialized, 1)
prevLevel := atomic.LoadUint32(&logLevel) prevLevel := atomic.LoadUint32(&logLevel)
SetLevel(InfoLevel) SetLevel(InfoLevel)
@@ -44,5 +44,9 @@ func captureOutput(f func()) string {
func getContent(jsonStr string) string { func getContent(jsonStr string) string {
var entry logEntry var entry logEntry
json.Unmarshal([]byte(jsonStr), &entry) json.Unmarshal([]byte(jsonStr), &entry)
return entry.Content.(string) val, ok := entry.Content.(string)
if ok {
return val
}
return ""
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"sync/atomic"
"time" "time"
"github.com/zeromicro/go-zero/core/timex" "github.com/zeromicro/go-zero/core/timex"
@@ -77,16 +78,24 @@ func (l *traceLogger) WithDuration(duration time.Duration) Logger {
} }
func (l *traceLogger) write(writer io.Writer, level string, val interface{}) { func (l *traceLogger) write(writer io.Writer, level string, val interface{}) {
outputJson(writer, &traceLogger{ traceID := traceIdFromContext(l.ctx)
logEntry: logEntry{ spanID := spanIdFromContext(l.ctx)
Timestamp: getTimestamp(),
Level: level, switch atomic.LoadUint32(&encoding) {
Duration: l.Duration, case plainEncodingType:
Content: val, writePlainAny(writer, level, val, l.Duration, traceID, spanID)
}, default:
Trace: traceIdFromContext(l.ctx), outputJson(writer, &traceLogger{
Span: spanIdFromContext(l.ctx), logEntry: logEntry{
}) Timestamp: getTimestamp(),
Level: level,
Duration: l.Duration,
Content: val,
},
Trace: traceID,
Span: spanID,
})
}
} }
// WithContext sets ctx to log, for keeping tracing information. // WithContext sets ctx to log, for keeping tracing information.

View File

@@ -82,6 +82,37 @@ func TestTraceInfo(t *testing.T) {
assert.True(t, strings.Contains(buf.String(), spanKey)) assert.True(t, strings.Contains(buf.String(), spanKey))
} }
func TestTraceInfoConsole(t *testing.T) {
old := atomic.LoadUint32(&encoding)
atomic.StoreUint32(&encoding, jsonEncodingType)
defer func() {
atomic.StoreUint32(&encoding, old)
}()
var buf mockWriter
atomic.StoreUint32(&initialized, 1)
infoLog = newLogWriter(log.New(&buf, "", flags))
otp := otel.GetTracerProvider()
tp := sdktrace.NewTracerProvider(sdktrace.WithSampler(sdktrace.AlwaysSample()))
otel.SetTracerProvider(tp)
defer otel.SetTracerProvider(otp)
ctx, _ := tp.Tracer("foo").Start(context.Background(), "bar")
l := WithContext(ctx).(*traceLogger)
SetLevel(InfoLevel)
l.WithDuration(time.Second).Info(testlog)
assert.True(t, strings.Contains(buf.String(), traceIdFromContext(ctx)))
assert.True(t, strings.Contains(buf.String(), spanIdFromContext(ctx)))
buf.Reset()
l.WithDuration(time.Second).Infof(testlog)
assert.True(t, strings.Contains(buf.String(), traceIdFromContext(ctx)))
assert.True(t, strings.Contains(buf.String(), spanIdFromContext(ctx)))
buf.Reset()
l.WithDuration(time.Second).Infov(testlog)
assert.True(t, strings.Contains(buf.String(), traceIdFromContext(ctx)))
assert.True(t, strings.Contains(buf.String(), spanIdFromContext(ctx)))
}
func TestTraceSlow(t *testing.T) { func TestTraceSlow(t *testing.T) {
var buf mockWriter var buf mockWriter
atomic.StoreUint32(&initialized, 1) atomic.StoreUint32(&initialized, 1)

186
core/mapping/marshaler.go Normal file
View File

@@ -0,0 +1,186 @@
package mapping
import (
"fmt"
"reflect"
"strings"
)
const (
emptyTag = ""
tagKVSeparator = ":"
)
// Marshal marshals the given val and returns the map that contains the fields.
// optional=another is not implemented, and it's hard to implement and not common used.
func Marshal(val interface{}) (map[string]map[string]interface{}, error) {
ret := make(map[string]map[string]interface{})
tp := reflect.TypeOf(val)
if tp.Kind() == reflect.Ptr {
tp = tp.Elem()
}
rv := reflect.ValueOf(val)
if rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
for i := 0; i < tp.NumField(); i++ {
field := tp.Field(i)
value := rv.Field(i)
if err := processMember(field, value, ret); err != nil {
return nil, err
}
}
return ret, nil
}
func getTag(field reflect.StructField) (string, bool) {
tag := string(field.Tag)
if i := strings.Index(tag, tagKVSeparator); i >= 0 {
return strings.TrimSpace(tag[:i]), true
}
return strings.TrimSpace(tag), false
}
func processMember(field reflect.StructField, value reflect.Value,
collector map[string]map[string]interface{}) error {
var key string
var opt *fieldOptions
var err error
tag, ok := getTag(field)
if !ok {
tag = emptyTag
key = field.Name
} else {
key, opt, err = parseKeyAndOptions(tag, field)
if err != nil {
return err
}
if err = validate(field, value, opt); err != nil {
return err
}
}
val := value.Interface()
if opt != nil && opt.FromString {
val = fmt.Sprint(val)
}
m, ok := collector[tag]
if ok {
m[key] = val
} else {
m = map[string]interface{}{
key: val,
}
}
collector[tag] = m
return nil
}
func validate(field reflect.StructField, value reflect.Value, opt *fieldOptions) error {
if opt == nil || !opt.Optional {
if err := validateOptional(field, value); err != nil {
return err
}
}
if opt == nil {
return nil
}
if opt.Optional && value.IsZero() {
return nil
}
if len(opt.Options) > 0 {
if err := validateOptions(value, opt); err != nil {
return err
}
}
if opt.Range != nil {
if err := validateRange(value, opt); err != nil {
return err
}
}
return nil
}
func validateOptional(field reflect.StructField, value reflect.Value) error {
switch field.Type.Kind() {
case reflect.Ptr:
if value.IsNil() {
return fmt.Errorf("field %q is nil", field.Name)
}
case reflect.Array, reflect.Slice, reflect.Map:
if value.IsNil() || value.Len() == 0 {
return fmt.Errorf("field %q is empty", field.Name)
}
}
return nil
}
func validateOptions(value reflect.Value, opt *fieldOptions) error {
var found bool
val := fmt.Sprint(value.Interface())
for i := range opt.Options {
if opt.Options[i] == val {
found = true
break
}
}
if !found {
return fmt.Errorf("field %q not in options", val)
}
return nil
}
func validateRange(value reflect.Value, opt *fieldOptions) error {
var val float64
switch v := value.Interface().(type) {
case int:
val = float64(v)
case int8:
val = float64(v)
case int16:
val = float64(v)
case int32:
val = float64(v)
case int64:
val = float64(v)
case uint:
val = float64(v)
case uint8:
val = float64(v)
case uint16:
val = float64(v)
case uint32:
val = float64(v)
case uint64:
val = float64(v)
case float32:
val = float64(v)
case float64:
val = v
default:
return fmt.Errorf("unknown support type for range %q", value.Type().String())
}
// validates [left, right], [left, right), (left, right], (left, right)
if val < opt.Range.left ||
(!opt.Range.leftInclude && val == opt.Range.left) ||
val > opt.Range.right ||
(!opt.Range.rightInclude && val == opt.Range.right) {
return fmt.Errorf("%v out of range", value.Interface())
}
return nil
}

View File

@@ -0,0 +1,274 @@
package mapping
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestMarshal(t *testing.T) {
v := struct {
Name string `path:"name"`
Address string `json:"address,options=[beijing,shanghai]"`
Age int `json:"age"`
Anonymous bool
}{
Name: "kevin",
Address: "shanghai",
Age: 20,
Anonymous: true,
}
m, err := Marshal(v)
assert.Nil(t, err)
assert.Equal(t, "kevin", m["path"]["name"])
assert.Equal(t, "shanghai", m["json"]["address"])
assert.Equal(t, 20, m["json"]["age"].(int))
assert.True(t, m[emptyTag]["Anonymous"].(bool))
}
func TestMarshal_Ptr(t *testing.T) {
v := &struct {
Name string `path:"name"`
Address string `json:"address,options=[beijing,shanghai]"`
Age int `json:"age"`
Anonymous bool
}{
Name: "kevin",
Address: "shanghai",
Age: 20,
Anonymous: true,
}
m, err := Marshal(v)
assert.Nil(t, err)
assert.Equal(t, "kevin", m["path"]["name"])
assert.Equal(t, "shanghai", m["json"]["address"])
assert.Equal(t, 20, m["json"]["age"].(int))
assert.True(t, m[emptyTag]["Anonymous"].(bool))
}
func TestMarshal_OptionalPtr(t *testing.T) {
var val = 1
v := struct {
Age *int `json:"age"`
}{
Age: &val,
}
m, err := Marshal(v)
assert.Nil(t, err)
assert.Equal(t, 1, *m["json"]["age"].(*int))
}
func TestMarshal_OptionalPtrNil(t *testing.T) {
v := struct {
Age *int `json:"age"`
}{}
_, err := Marshal(v)
assert.NotNil(t, err)
}
func TestMarshal_BadOptions(t *testing.T) {
v := struct {
Name string `json:"name,options"`
}{
Name: "kevin",
}
_, err := Marshal(v)
assert.NotNil(t, err)
}
func TestMarshal_NotInOptions(t *testing.T) {
v := struct {
Name string `json:"name,options=[a,b]"`
}{
Name: "kevin",
}
_, err := Marshal(v)
assert.NotNil(t, err)
}
func TestMarshal_NotInOptionsOptional(t *testing.T) {
v := struct {
Name string `json:"name,options=[a,b],optional"`
}{}
_, err := Marshal(v)
assert.Nil(t, err)
}
func TestMarshal_NotInOptionsOptionalWrongValue(t *testing.T) {
v := struct {
Name string `json:"name,options=[a,b],optional"`
}{
Name: "kevin",
}
_, err := Marshal(v)
assert.NotNil(t, err)
}
func TestMarshal_Nested(t *testing.T) {
type address struct {
Country string `json:"country"`
City string `json:"city"`
}
v := struct {
Name string `json:"name,options=[kevin,wan]"`
Address address `json:"address"`
}{
Name: "kevin",
Address: address{
Country: "China",
City: "Shanghai",
},
}
m, err := Marshal(v)
assert.Nil(t, err)
assert.Equal(t, "kevin", m["json"]["name"])
assert.Equal(t, "China", m["json"]["address"].(address).Country)
assert.Equal(t, "Shanghai", m["json"]["address"].(address).City)
}
func TestMarshal_NestedPtr(t *testing.T) {
type address struct {
Country string `json:"country"`
City string `json:"city"`
}
v := struct {
Name string `json:"name,options=[kevin,wan]"`
Address *address `json:"address"`
}{
Name: "kevin",
Address: &address{
Country: "China",
City: "Shanghai",
},
}
m, err := Marshal(v)
assert.Nil(t, err)
assert.Equal(t, "kevin", m["json"]["name"])
assert.Equal(t, "China", m["json"]["address"].(*address).Country)
assert.Equal(t, "Shanghai", m["json"]["address"].(*address).City)
}
func TestMarshal_Slice(t *testing.T) {
v := struct {
Name []string `json:"name"`
}{
Name: []string{"kevin", "wan"},
}
m, err := Marshal(v)
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"kevin", "wan"}, m["json"]["name"].([]string))
}
func TestMarshal_SliceNil(t *testing.T) {
v := struct {
Name []string `json:"name"`
}{
Name: nil,
}
_, err := Marshal(v)
assert.NotNil(t, err)
}
func TestMarshal_Range(t *testing.T) {
v := struct {
Int int `json:"int,range=[1:3]"`
Int8 int8 `json:"int8,range=[1:3)"`
Int16 int16 `json:"int16,range=(1:3]"`
Int32 int32 `json:"int32,range=(1:3)"`
Int64 int64 `json:"int64,range=(1:3)"`
Uint uint `json:"uint,range=[1:3]"`
Uint8 uint8 `json:"uint8,range=[1:3)"`
Uint16 uint16 `json:"uint16,range=(1:3]"`
Uint32 uint32 `json:"uint32,range=(1:3)"`
Uint64 uint64 `json:"uint64,range=(1:3)"`
Float32 float32 `json:"float32,range=(1:3)"`
Float64 float64 `json:"float64,range=(1:3)"`
}{
Int: 1,
Int8: 1,
Int16: 2,
Int32: 2,
Int64: 2,
Uint: 1,
Uint8: 1,
Uint16: 2,
Uint32: 2,
Uint64: 2,
Float32: 2,
Float64: 2,
}
m, err := Marshal(v)
assert.Nil(t, err)
assert.Equal(t, 1, m["json"]["int"].(int))
assert.Equal(t, int8(1), m["json"]["int8"].(int8))
assert.Equal(t, int16(2), m["json"]["int16"].(int16))
assert.Equal(t, int32(2), m["json"]["int32"].(int32))
assert.Equal(t, int64(2), m["json"]["int64"].(int64))
assert.Equal(t, uint(1), m["json"]["uint"].(uint))
assert.Equal(t, uint8(1), m["json"]["uint8"].(uint8))
assert.Equal(t, uint16(2), m["json"]["uint16"].(uint16))
assert.Equal(t, uint32(2), m["json"]["uint32"].(uint32))
assert.Equal(t, uint64(2), m["json"]["uint64"].(uint64))
assert.Equal(t, float32(2), m["json"]["float32"].(float32))
assert.Equal(t, float64(2), m["json"]["float64"].(float64))
}
func TestMarshal_RangeOut(t *testing.T) {
tests := []interface{}{
struct {
Int int `json:"int,range=[1:3]"`
}{
Int: 4,
},
struct {
Int int `json:"int,range=(1:3]"`
}{
Int: 1,
},
struct {
Int int `json:"int,range=[1:3)"`
}{
Int: 3,
},
struct {
Int int `json:"int,range=(1:3)"`
}{
Int: 3,
},
struct {
Bool bool `json:"bool,range=(1:3)"`
}{
Bool: true,
},
}
for _, test := range tests {
_, err := Marshal(test)
assert.NotNil(t, err)
}
}
func TestMarshal_FromString(t *testing.T) {
v := struct {
Age int `json:"age,string"`
}{
Age: 10,
}
m, err := Marshal(v)
assert.Nil(t, err)
assert.Equal(t, "10", m["json"]["age"].(string))
}

View File

@@ -97,10 +97,6 @@ func (u *Unmarshaler) unmarshalWithFullName(m Valuer, v interface{}, fullName st
numFields := rte.NumField() numFields := rte.NumField()
for i := 0; i < numFields; i++ { for i := 0; i < numFields; i++ {
field := rte.Field(i) field := rte.Field(i)
if usingDifferentKeys(u.key, field) {
continue
}
if err := u.processField(field, rve.Field(i), m, fullName); err != nil { if err := u.processField(field, rve.Field(i), m, fullName); err != nil {
return err return err
} }
@@ -275,6 +271,10 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(field reflect.StructFi
return err return err
} }
if iValue < 0 {
return fmt.Errorf("unmarshal %q with bad value %q", fullName, v.String())
}
value.SetUint(uint64(iValue)) value.SetUint(uint64(iValue))
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
fValue, err := v.Float64() fValue, err := v.Float64()
@@ -448,7 +448,15 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
dereffedBaseType := Deref(baseType) dereffedBaseType := Deref(baseType)
dereffedBaseKind := dereffedBaseType.Kind() dereffedBaseKind := dereffedBaseType.Kind()
refValue := reflect.ValueOf(mapValue) refValue := reflect.ValueOf(mapValue)
if refValue.IsNil() {
return nil
}
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap()) conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
if refValue.Len() == 0 {
value.Set(conv)
return nil
}
var valid bool var valid bool
for i := 0; i < refValue.Len(); i++ { for i := 0; i < refValue.Len(); i++ {
@@ -723,10 +731,10 @@ func fillWithSameType(field reflect.StructField, value reflect.Value, mapValue i
if field.Type.Kind() == reflect.Ptr { if field.Type.Kind() == reflect.Ptr {
baseType := Deref(field.Type) baseType := Deref(field.Type)
target := reflect.New(baseType).Elem() target := reflect.New(baseType).Elem()
target.Set(reflect.ValueOf(mapValue)) setSameKindValue(baseType, target, mapValue)
value.Set(target.Addr()) value.Set(target.Addr())
} else { } else {
value.Set(reflect.ValueOf(mapValue)) setSameKindValue(field.Type, value, mapValue)
} }
return nil return nil
@@ -742,7 +750,9 @@ func getValueWithChainedKeys(m Valuer, keys []string) (interface{}, bool) {
if len(keys) == 1 { if len(keys) == 1 {
v, ok := m.Value(keys[0]) v, ok := m.Value(keys[0])
return v, ok return v, ok
} else if len(keys) > 1 { }
if len(keys) > 1 {
if v, ok := m.Value(keys[0]); ok { if v, ok := m.Value(keys[0]); ok {
if nextm, ok := v.(map[string]interface{}); ok { if nextm, ok := v.(map[string]interface{}); ok {
return getValueWithChainedKeys(MapValuer(nextm), keys[1:]) return getValueWithChainedKeys(MapValuer(nextm), keys[1:])
@@ -799,3 +809,11 @@ func readKeys(key string) []string {
return keys return keys
} }
func setSameKindValue(targetType reflect.Type, target reflect.Value, value interface{}) {
if reflect.ValueOf(value).Type().AssignableTo(targetType) {
target.Set(reflect.ValueOf(value))
} else {
target.Set(reflect.ValueOf(value).Convert(targetType))
}
}

View File

@@ -198,6 +198,49 @@ func TestUnmarshalIntWithDefault(t *testing.T) {
assert.Equal(t, 1, in.Int) assert.Equal(t, 1, in.Int)
} }
func TestUnmarshalBoolSliceRequired(t *testing.T) {
type inner struct {
Bools []bool `key:"bools"`
}
var in inner
assert.NotNil(t, UnmarshalKey(map[string]interface{}{}, &in))
}
func TestUnmarshalBoolSliceNil(t *testing.T) {
type inner struct {
Bools []bool `key:"bools,optional"`
}
var in inner
assert.Nil(t, UnmarshalKey(map[string]interface{}{}, &in))
assert.Nil(t, in.Bools)
}
func TestUnmarshalBoolSliceNilExplicit(t *testing.T) {
type inner struct {
Bools []bool `key:"bools,optional"`
}
var in inner
assert.Nil(t, UnmarshalKey(map[string]interface{}{
"bools": nil,
}, &in))
assert.Nil(t, in.Bools)
}
func TestUnmarshalBoolSliceEmpty(t *testing.T) {
type inner struct {
Bools []bool `key:"bools,optional"`
}
var in inner
assert.Nil(t, UnmarshalKey(map[string]interface{}{
"bools": []bool{},
}, &in))
assert.Empty(t, in.Bools)
}
func TestUnmarshalBoolSliceWithDefault(t *testing.T) { func TestUnmarshalBoolSliceWithDefault(t *testing.T) {
type inner struct { type inner struct {
Bools []bool `key:"bools,default=[true,false]"` Bools []bool `key:"bools,default=[true,false]"`
@@ -330,28 +373,34 @@ func TestUnmarshalFloat(t *testing.T) {
func TestUnmarshalInt64Slice(t *testing.T) { func TestUnmarshalInt64Slice(t *testing.T) {
var v struct { var v struct {
Ages []int64 `key:"ages"` Ages []int64 `key:"ages"`
Slice []int64 `key:"slice"`
} }
m := map[string]interface{}{ m := map[string]interface{}{
"ages": []int64{1, 2}, "ages": []int64{1, 2},
"slice": []interface{}{},
} }
ast := assert.New(t) ast := assert.New(t)
ast.Nil(UnmarshalKey(m, &v)) ast.Nil(UnmarshalKey(m, &v))
ast.ElementsMatch([]int64{1, 2}, v.Ages) ast.ElementsMatch([]int64{1, 2}, v.Ages)
ast.Equal([]int64{}, v.Slice)
} }
func TestUnmarshalIntSlice(t *testing.T) { func TestUnmarshalIntSlice(t *testing.T) {
var v struct { var v struct {
Ages []int `key:"ages"` Ages []int `key:"ages"`
Slice []int `key:"slice"`
} }
m := map[string]interface{}{ m := map[string]interface{}{
"ages": []int{1, 2}, "ages": []int{1, 2},
"slice": []interface{}{},
} }
ast := assert.New(t) ast := assert.New(t)
ast.Nil(UnmarshalKey(m, &v)) ast.Nil(UnmarshalKey(m, &v))
ast.ElementsMatch([]int{1, 2}, v.Ages) ast.ElementsMatch([]int{1, 2}, v.Ages)
ast.Equal([]int{}, v.Slice)
} }
func TestUnmarshalString(t *testing.T) { func TestUnmarshalString(t *testing.T) {
@@ -938,6 +987,43 @@ func TestUnmarshalWithStringOptionsCorrect(t *testing.T) {
ast.Equal("2", in.Correct) ast.Equal("2", in.Correct)
} }
func TestUnmarshalOptionsOptional(t *testing.T) {
type inner struct {
Value string `key:"value,options=first|second,optional"`
OptionalValue string `key:"optional_value,options=first|second,optional"`
Foo string `key:"foo,options=[bar,baz]"`
Correct string `key:"correct,options=1|2"`
}
m := map[string]interface{}{
"value": "first",
"foo": "bar",
"correct": "2",
}
var in inner
ast := assert.New(t)
ast.Nil(UnmarshalKey(m, &in))
ast.Equal("first", in.Value)
ast.Equal("", in.OptionalValue)
ast.Equal("bar", in.Foo)
ast.Equal("2", in.Correct)
}
func TestUnmarshalOptionsOptionalWrongValue(t *testing.T) {
type inner struct {
Value string `key:"value,options=first|second,optional"`
OptionalValue string `key:"optional_value,options=first|second,optional"`
WrongValue string `key:"wrong_value,options=first|second,optional"`
}
m := map[string]interface{}{
"value": "first",
"wrong_value": "third",
}
var in inner
assert.NotNil(t, UnmarshalKey(m, &in))
}
func TestUnmarshalStringOptionsWithStringOptionsNotString(t *testing.T) { func TestUnmarshalStringOptionsWithStringOptionsNotString(t *testing.T) {
type inner struct { type inner struct {
Value string `key:"value,options=first|second"` Value string `key:"value,options=first|second"`
@@ -2611,6 +2697,86 @@ func TestUnmarshalJsonWithoutKey(t *testing.T) {
assert.Equal(t, "2", res.B) assert.Equal(t, "2", res.B)
} }
func TestUnmarshalJsonUintNegative(t *testing.T) {
payload := `{"a": -1}`
var res struct {
A uint `json:"a"`
}
reader := strings.NewReader(payload)
err := UnmarshalJsonReader(reader, &res)
assert.NotNil(t, err)
}
func TestUnmarshalJsonDefinedInt(t *testing.T) {
type Int int
var res struct {
A Int `json:"a"`
}
payload := `{"a": -1}`
reader := strings.NewReader(payload)
err := UnmarshalJsonReader(reader, &res)
assert.Nil(t, err)
assert.Equal(t, Int(-1), res.A)
}
func TestUnmarshalJsonDefinedString(t *testing.T) {
type String string
var res struct {
A String `json:"a"`
}
payload := `{"a": "foo"}`
reader := strings.NewReader(payload)
err := UnmarshalJsonReader(reader, &res)
assert.Nil(t, err)
assert.Equal(t, String("foo"), res.A)
}
func TestUnmarshalJsonDefinedStringPtr(t *testing.T) {
type String string
var res struct {
A *String `json:"a"`
}
payload := `{"a": "foo"}`
reader := strings.NewReader(payload)
err := UnmarshalJsonReader(reader, &res)
assert.Nil(t, err)
assert.Equal(t, String("foo"), *res.A)
}
func TestUnmarshalJsonReaderComplex(t *testing.T) {
type (
MyInt int
MyTxt string
MyTxtArray []string
Req struct {
MyInt MyInt `json:"my_int"` // int.. ok
MyTxtArray MyTxtArray `json:"my_txt_array"`
MyTxt MyTxt `json:"my_txt"` // but string is not assignable
Int int `json:"int"`
Txt string `json:"txt"`
}
)
body := `{
"my_int": 100,
"my_txt_array": [
"a",
"b"
],
"my_txt": "my_txt",
"int": 200,
"txt": "txt"
}`
var req Req
err := UnmarshalJsonReader(strings.NewReader(body), &req)
assert.Nil(t, err)
assert.Equal(t, MyInt(100), req.MyInt)
assert.Equal(t, MyTxt("my_txt"), req.MyTxt)
assert.EqualValues(t, MyTxtArray([]string{"a", "b"}), req.MyTxtArray)
assert.Equal(t, 200, req.Int)
assert.Equal(t, "txt", req.Txt)
}
func BenchmarkDefaultValue(b *testing.B) { func BenchmarkDefaultValue(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var a struct { var a struct {

View File

@@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
"io/ioutil"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@@ -14,7 +13,7 @@ const yamlTagKey = "json"
var ( var (
// ErrUnsupportedType is an error that indicates the config format is not supported. // ErrUnsupportedType is an error that indicates the config format is not supported.
ErrUnsupportedType = errors.New("only map-like configs are suported") ErrUnsupportedType = errors.New("only map-like configs are supported")
yamlUnmarshaler = NewUnmarshaler(yamlTagKey) yamlUnmarshaler = NewUnmarshaler(yamlTagKey)
) )
@@ -29,39 +28,6 @@ func UnmarshalYamlReader(reader io.Reader, v interface{}) error {
return unmarshalYamlReader(reader, v, yamlUnmarshaler) return unmarshalYamlReader(reader, v, yamlUnmarshaler)
} }
func unmarshalYamlBytes(content []byte, v interface{}, unmarshaler *Unmarshaler) error {
var o interface{}
if err := yamlUnmarshal(content, &o); err != nil {
return err
}
if m, ok := o.(map[string]interface{}); ok {
return unmarshaler.Unmarshal(m, v)
}
return ErrUnsupportedType
}
func unmarshalYamlReader(reader io.Reader, v interface{}, unmarshaler *Unmarshaler) error {
content, err := ioutil.ReadAll(reader)
if err != nil {
return err
}
return unmarshalYamlBytes(content, v, unmarshaler)
}
// yamlUnmarshal YAML to map[string]interface{} instead of map[interface{}]interface{}.
func yamlUnmarshal(in []byte, out interface{}) error {
var res interface{}
if err := yaml.Unmarshal(in, &res); err != nil {
return err
}
*out.(*interface{}) = cleanupMapValue(res)
return nil
}
func cleanupInterfaceMap(in map[interface{}]interface{}) map[string]interface{} { func cleanupInterfaceMap(in map[interface{}]interface{}) map[string]interface{} {
res := make(map[string]interface{}) res := make(map[string]interface{})
for k, v := range in { for k, v := range in {
@@ -96,3 +62,40 @@ func cleanupMapValue(v interface{}) interface{} {
return Repr(v) return Repr(v)
} }
} }
func unmarshal(unmarshaler *Unmarshaler, o interface{}, v interface{}) error {
if m, ok := o.(map[string]interface{}); ok {
return unmarshaler.Unmarshal(m, v)
}
return ErrUnsupportedType
}
func unmarshalYamlBytes(content []byte, v interface{}, unmarshaler *Unmarshaler) error {
var o interface{}
if err := yamlUnmarshal(content, &o); err != nil {
return err
}
return unmarshal(unmarshaler, o, v)
}
func unmarshalYamlReader(reader io.Reader, v interface{}, unmarshaler *Unmarshaler) error {
var res interface{}
if err := yaml.NewDecoder(reader).Decode(&res); err != nil {
return err
}
return unmarshal(unmarshaler, cleanupMapValue(res), v)
}
// yamlUnmarshal YAML to map[string]interface{} instead of map[interface{}]interface{}.
func yamlUnmarshal(in []byte, out interface{}) error {
var res interface{}
if err := yaml.Unmarshal(in, &res); err != nil {
return err
}
*out.(*interface{}) = cleanupMapValue(res)
return nil
}

View File

@@ -926,14 +926,17 @@ func TestUnmarshalYamlBytesError(t *testing.T) {
} }
func TestUnmarshalYamlReaderError(t *testing.T) { func TestUnmarshalYamlReaderError(t *testing.T) {
payload := `abcd: cdef`
reader := strings.NewReader(payload)
var v struct { var v struct {
Any string Any string
} }
reader := strings.NewReader(`abcd: cdef`)
err := UnmarshalYamlReader(reader, &v) err := UnmarshalYamlReader(reader, &v)
assert.NotNil(t, err) assert.NotNil(t, err)
reader = strings.NewReader("chenquan")
err = UnmarshalYamlReader(reader, &v)
assert.ErrorIs(t, err, ErrUnsupportedType)
} }
func TestUnmarshalYamlBadReader(t *testing.T) { func TestUnmarshalYamlBadReader(t *testing.T) {
@@ -1011,6 +1014,6 @@ func TestUnmarshalYamlMapRune(t *testing.T) {
type badReader struct{} type badReader struct{}
func (b *badReader) Read(p []byte) (n int, err error) { func (b *badReader) Read(_ []byte) (n int, err error) {
return 0, io.ErrLimitReached return 0, io.ErrLimitReached
} }

View File

@@ -3,12 +3,11 @@ package mr
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"sync" "sync"
"sync/atomic"
"github.com/zeromicro/go-zero/core/errorx" "github.com/zeromicro/go-zero/core/errorx"
"github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/threading"
) )
const ( const (
@@ -24,12 +23,12 @@ var (
) )
type ( type (
// ForEachFunc is used to do element processing, but no output.
ForEachFunc func(item interface{})
// GenerateFunc is used to let callers send elements into source. // GenerateFunc is used to let callers send elements into source.
GenerateFunc func(source chan<- interface{}) GenerateFunc func(source chan<- interface{})
// MapFunc is used to do element processing and write the output to writer. // MapFunc is used to do element processing and write the output to writer.
MapFunc func(item interface{}, writer Writer) MapFunc func(item interface{}, writer Writer)
// VoidMapFunc is used to do element processing, but no output.
VoidMapFunc func(item interface{})
// MapperFunc is used to do element processing and write the output to writer, // MapperFunc is used to do element processing and write the output to writer,
// use cancel func to cancel the processing. // use cancel func to cancel the processing.
MapperFunc func(item interface{}, writer Writer, cancel func(error)) MapperFunc func(item interface{}, writer Writer, cancel func(error))
@@ -42,6 +41,16 @@ type (
// Option defines the method to customize the mapreduce. // Option defines the method to customize the mapreduce.
Option func(opts *mapReduceOptions) Option func(opts *mapReduceOptions)
mapperContext struct {
ctx context.Context
mapper MapFunc
source <-chan interface{}
panicChan *onceChan
collector chan<- interface{}
doneChan <-chan lang.PlaceholderType
workers int
}
mapReduceOptions struct { mapReduceOptions struct {
ctx context.Context ctx context.Context
workers int workers int
@@ -69,7 +78,6 @@ func Finish(fns ...func() error) error {
cancel(err) cancel(err)
} }
}, func(pipe <-chan interface{}, cancel func(error)) { }, func(pipe <-chan interface{}, cancel func(error)) {
drain(pipe)
}, WithWorkers(len(fns))) }, WithWorkers(len(fns)))
} }
@@ -79,7 +87,7 @@ func FinishVoid(fns ...func()) {
return return
} }
MapVoid(func(source chan<- interface{}) { ForEach(func(source chan<- interface{}) {
for _, fn := range fns { for _, fn := range fns {
source <- fn source <- fn
} }
@@ -89,41 +97,74 @@ func FinishVoid(fns ...func()) {
}, WithWorkers(len(fns))) }, WithWorkers(len(fns)))
} }
// Map maps all elements generated from given generate func, and returns an output channel. // ForEach maps all elements from given generate but no output.
func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{} { func ForEach(generate GenerateFunc, mapper ForEachFunc, opts ...Option) {
options := buildOptions(opts...) options := buildOptions(opts...)
source := buildSource(generate) panicChan := &onceChan{channel: make(chan interface{})}
source := buildSource(generate, panicChan)
collector := make(chan interface{}, options.workers) collector := make(chan interface{}, options.workers)
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
go executeMappers(options.ctx, mapper, source, collector, done, options.workers) go executeMappers(mapperContext{
ctx: options.ctx,
mapper: func(item interface{}, writer Writer) {
mapper(item)
},
source: source,
panicChan: panicChan,
collector: collector,
doneChan: done,
workers: options.workers,
})
return collector for {
select {
case v := <-panicChan.channel:
panic(v)
case _, ok := <-collector:
if !ok {
return
}
}
}
} }
// MapReduce maps all elements generated from given generate func, // MapReduce maps all elements generated from given generate func,
// and reduces the output elements with given reducer. // and reduces the output elements with given reducer.
func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc,
opts ...Option) (interface{}, error) { opts ...Option) (interface{}, error) {
source := buildSource(generate) panicChan := &onceChan{channel: make(chan interface{})}
return MapReduceWithSource(source, mapper, reducer, opts...) source := buildSource(generate, panicChan)
return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
} }
// MapReduceWithSource maps all elements from source, and reduce the output elements with given reducer. // MapReduceChan maps all elements from source, and reduce the output elements with given reducer.
func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc, func MapReduceChan(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc,
opts ...Option) (interface{}, error) { opts ...Option) (interface{}, error) {
panicChan := &onceChan{channel: make(chan interface{})}
return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
}
// MapReduceChan maps all elements from source, and reduce the output elements with given reducer.
func mapReduceWithPanicChan(source <-chan interface{}, panicChan *onceChan, mapper MapperFunc,
reducer ReducerFunc, opts ...Option) (interface{}, error) {
options := buildOptions(opts...) options := buildOptions(opts...)
// output is used to write the final result
output := make(chan interface{}) output := make(chan interface{})
defer func() { defer func() {
// reducer can only write once, if more, panic
for range output { for range output {
panic("more than one element written in reducer") panic("more than one element written in reducer")
} }
}() }()
// collector is used to collect data from mapper, and consume in reducer
collector := make(chan interface{}, options.workers) collector := make(chan interface{}, options.workers)
// if done is closed, all mappers and reducer should stop processing
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
writer := newGuardedWriter(options.ctx, output, done) writer := newGuardedWriter(options.ctx, output, done)
var closeOnce sync.Once var closeOnce sync.Once
// use atomic.Value to avoid data race
var retErr errorx.AtomicError var retErr errorx.AtomicError
finish := func() { finish := func() {
closeOnce.Do(func() { closeOnce.Do(func() {
@@ -145,28 +186,41 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
go func() { go func() {
defer func() { defer func() {
drain(collector) drain(collector)
if r := recover(); r != nil { if r := recover(); r != nil {
cancel(fmt.Errorf("%v", r)) panicChan.write(r)
} else {
finish()
} }
finish()
}() }()
reducer(collector, writer, cancel) reducer(collector, writer, cancel)
}() }()
go executeMappers(options.ctx, func(item interface{}, w Writer) { go executeMappers(mapperContext{
mapper(item, w, cancel) ctx: options.ctx,
}, source, collector, done, options.workers) mapper: func(item interface{}, w Writer) {
mapper(item, w, cancel)
},
source: source,
panicChan: panicChan,
collector: collector,
doneChan: done,
workers: options.workers,
})
value, ok := <-output select {
if err := retErr.Load(); err != nil { case <-options.ctx.Done():
return nil, err cancel(context.DeadlineExceeded)
} else if ok { return nil, context.DeadlineExceeded
return value, nil case v := <-panicChan.channel:
} else { panic(v)
return nil, ErrReduceNoOutput case v, ok := <-output:
if err := retErr.Load(); err != nil {
return nil, err
} else if ok {
return v, nil
} else {
return nil, ErrReduceNoOutput
}
} }
} }
@@ -175,18 +229,12 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
func MapReduceVoid(generate GenerateFunc, mapper MapperFunc, reducer VoidReducerFunc, opts ...Option) error { func MapReduceVoid(generate GenerateFunc, mapper MapperFunc, reducer VoidReducerFunc, opts ...Option) error {
_, err := MapReduce(generate, mapper, func(input <-chan interface{}, writer Writer, cancel func(error)) { _, err := MapReduce(generate, mapper, func(input <-chan interface{}, writer Writer, cancel func(error)) {
reducer(input, cancel) reducer(input, cancel)
// We need to write a placeholder to let MapReduce to continue on reducer done,
// otherwise, all goroutines are waiting. The placeholder will be discarded by MapReduce.
writer.Write(lang.Placeholder)
}, opts...) }, opts...)
return err if errors.Is(err, ErrReduceNoOutput) {
} return nil
}
// MapVoid maps all elements from given generate but no output. return err
func MapVoid(generate GenerateFunc, mapper VoidMapFunc, opts ...Option) {
drain(Map(generate, func(item interface{}, writer Writer) {
mapper(item)
}, opts...))
} }
// WithContext customizes a mapreduce processing accepts a given ctx. // WithContext customizes a mapreduce processing accepts a given ctx.
@@ -216,12 +264,18 @@ func buildOptions(opts ...Option) *mapReduceOptions {
return options return options
} }
func buildSource(generate GenerateFunc) chan interface{} { func buildSource(generate GenerateFunc, panicChan *onceChan) chan interface{} {
source := make(chan interface{}) source := make(chan interface{})
threading.GoSafe(func() { go func() {
defer close(source) defer func() {
if r := recover(); r != nil {
panicChan.write(r)
}
close(source)
}()
generate(source) generate(source)
}) }()
return source return source
} }
@@ -233,39 +287,43 @@ func drain(channel <-chan interface{}) {
} }
} }
func executeMappers(ctx context.Context, mapper MapFunc, input <-chan interface{}, func executeMappers(mCtx mapperContext) {
collector chan<- interface{}, done <-chan lang.PlaceholderType, workers int) {
var wg sync.WaitGroup var wg sync.WaitGroup
defer func() { defer func() {
wg.Wait() wg.Wait()
close(collector) close(mCtx.collector)
drain(mCtx.source)
}() }()
pool := make(chan lang.PlaceholderType, workers) var failed int32
writer := newGuardedWriter(ctx, collector, done) pool := make(chan lang.PlaceholderType, mCtx.workers)
for { writer := newGuardedWriter(mCtx.ctx, mCtx.collector, mCtx.doneChan)
for atomic.LoadInt32(&failed) == 0 {
select { select {
case <-ctx.Done(): case <-mCtx.ctx.Done():
return return
case <-done: case <-mCtx.doneChan:
return return
case pool <- lang.Placeholder: case pool <- lang.Placeholder:
item, ok := <-input item, ok := <-mCtx.source
if !ok { if !ok {
<-pool <-pool
return return
} }
wg.Add(1) wg.Add(1)
// better to safely run caller defined method go func() {
threading.GoSafe(func() {
defer func() { defer func() {
if r := recover(); r != nil {
atomic.AddInt32(&failed, 1)
mCtx.panicChan.write(r)
}
wg.Done() wg.Done()
<-pool <-pool
}() }()
mapper(item, writer) mCtx.mapper(item, writer)
}) }()
} }
} }
} }
@@ -311,3 +369,16 @@ func (gw guardedWriter) Write(v interface{}) {
gw.channel <- v gw.channel <- v
} }
} }
type onceChan struct {
channel chan interface{}
wrote int32
}
func (oc *onceChan) write(val interface{}) {
if atomic.AddInt32(&oc.wrote, 1) > 1 {
return
}
oc.channel <- val
}

View File

@@ -0,0 +1,78 @@
//go:build go1.18
// +build go1.18
package mr
import (
"fmt"
"math/rand"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
)
func FuzzMapReduce(f *testing.F) {
rand.Seed(time.Now().UnixNano())
f.Add(uint(10), uint(runtime.NumCPU()))
f.Fuzz(func(t *testing.T, num uint, workers uint) {
n := int64(num)%5000 + 5000
genPanic := rand.Intn(100) == 0
mapperPanic := rand.Intn(100) == 0
reducerPanic := rand.Intn(100) == 0
genIdx := rand.Int63n(n)
mapperIdx := rand.Int63n(n)
reducerIdx := rand.Int63n(n)
squareSum := (n - 1) * n * (2*n - 1) / 6
fn := func() (interface{}, error) {
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
return MapReduce(func(source chan<- interface{}) {
for i := int64(0); i < n; i++ {
source <- i
if genPanic && i == genIdx {
panic("foo")
}
}
}, func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int64)
if mapperPanic && v == mapperIdx {
panic("bar")
}
writer.Write(v * v)
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
var idx int64
var total int64
for v := range pipe {
if reducerPanic && idx == reducerIdx {
panic("baz")
}
total += v.(int64)
idx++
}
writer.Write(total)
}, WithWorkers(int(workers)%50+runtime.NumCPU()/2))
}
if genPanic || mapperPanic || reducerPanic {
var buf strings.Builder
buf.WriteString(fmt.Sprintf("n: %d", n))
buf.WriteString(fmt.Sprintf(", genPanic: %t", genPanic))
buf.WriteString(fmt.Sprintf(", mapperPanic: %t", mapperPanic))
buf.WriteString(fmt.Sprintf(", reducerPanic: %t", reducerPanic))
buf.WriteString(fmt.Sprintf(", genIdx: %d", genIdx))
buf.WriteString(fmt.Sprintf(", mapperIdx: %d", mapperIdx))
buf.WriteString(fmt.Sprintf(", reducerIdx: %d", reducerIdx))
assert.Panicsf(t, func() { fn() }, buf.String())
} else {
val, err := fn()
assert.Nil(t, err)
assert.Equal(t, squareSum, val.(int64))
}
})
}

View File

@@ -0,0 +1,107 @@
//go:build fuzz
// +build fuzz
package mr
import (
"fmt"
"math/rand"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/threading"
"gopkg.in/cheggaaa/pb.v1"
)
// If Fuzz stuck, we don't know why, because it only returns hung or unexpected,
// so we need to simulate the fuzz test in test mode.
func TestMapReduceRandom(t *testing.T) {
rand.Seed(time.Now().UnixNano())
const (
times = 10000
nRange = 500
mega = 1024 * 1024
)
bar := pb.New(times).Start()
runner := threading.NewTaskRunner(runtime.NumCPU())
var wg sync.WaitGroup
wg.Add(times)
for i := 0; i < times; i++ {
runner.Schedule(func() {
start := time.Now()
defer func() {
if time.Since(start) > time.Minute {
t.Fatal("timeout")
}
wg.Done()
}()
t.Run(strconv.Itoa(i), func(t *testing.T) {
n := rand.Int63n(nRange)%nRange + nRange
workers := rand.Int()%50 + runtime.NumCPU()/2
genPanic := rand.Intn(100) == 0
mapperPanic := rand.Intn(100) == 0
reducerPanic := rand.Intn(100) == 0
genIdx := rand.Int63n(n)
mapperIdx := rand.Int63n(n)
reducerIdx := rand.Int63n(n)
squareSum := (n - 1) * n * (2*n - 1) / 6
fn := func() (interface{}, error) {
return MapReduce(func(source chan<- interface{}) {
for i := int64(0); i < n; i++ {
source <- i
if genPanic && i == genIdx {
panic("foo")
}
}
}, func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int64)
if mapperPanic && v == mapperIdx {
panic("bar")
}
writer.Write(v * v)
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
var idx int64
var total int64
for v := range pipe {
if reducerPanic && idx == reducerIdx {
panic("baz")
}
total += v.(int64)
idx++
}
writer.Write(total)
}, WithWorkers(int(workers)%50+runtime.NumCPU()/2))
}
if genPanic || mapperPanic || reducerPanic {
var buf strings.Builder
buf.WriteString(fmt.Sprintf("n: %d", n))
buf.WriteString(fmt.Sprintf(", genPanic: %t", genPanic))
buf.WriteString(fmt.Sprintf(", mapperPanic: %t", mapperPanic))
buf.WriteString(fmt.Sprintf(", reducerPanic: %t", reducerPanic))
buf.WriteString(fmt.Sprintf(", genIdx: %d", genIdx))
buf.WriteString(fmt.Sprintf(", mapperIdx: %d", mapperIdx))
buf.WriteString(fmt.Sprintf(", reducerIdx: %d", reducerIdx))
assert.Panicsf(t, func() { fn() }, buf.String())
} else {
val, err := fn()
assert.Nil(t, err)
assert.Equal(t, squareSum, val.(int64))
}
bar.Increment()
})
})
}
wg.Wait()
bar.Finish()
}

View File

@@ -11,8 +11,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stringx" "go.uber.org/goleak"
"github.com/zeromicro/go-zero/core/syncx"
) )
var errDummy = errors.New("dummy") var errDummy = errors.New("dummy")
@@ -22,6 +21,8 @@ func init() {
} }
func TestFinish(t *testing.T) { func TestFinish(t *testing.T) {
defer goleak.VerifyNone(t)
var total uint32 var total uint32
err := Finish(func() error { err := Finish(func() error {
atomic.AddUint32(&total, 2) atomic.AddUint32(&total, 2)
@@ -39,14 +40,20 @@ func TestFinish(t *testing.T) {
} }
func TestFinishNone(t *testing.T) { func TestFinishNone(t *testing.T) {
defer goleak.VerifyNone(t)
assert.Nil(t, Finish()) assert.Nil(t, Finish())
} }
func TestFinishVoidNone(t *testing.T) { func TestFinishVoidNone(t *testing.T) {
defer goleak.VerifyNone(t)
FinishVoid() FinishVoid()
} }
func TestFinishErr(t *testing.T) { func TestFinishErr(t *testing.T) {
defer goleak.VerifyNone(t)
var total uint32 var total uint32
err := Finish(func() error { err := Finish(func() error {
atomic.AddUint32(&total, 2) atomic.AddUint32(&total, 2)
@@ -63,6 +70,8 @@ func TestFinishErr(t *testing.T) {
} }
func TestFinishVoid(t *testing.T) { func TestFinishVoid(t *testing.T) {
defer goleak.VerifyNone(t)
var total uint32 var total uint32
FinishVoid(func() { FinishVoid(func() {
atomic.AddUint32(&total, 2) atomic.AddUint32(&total, 2)
@@ -75,70 +84,107 @@ func TestFinishVoid(t *testing.T) {
assert.Equal(t, uint32(10), atomic.LoadUint32(&total)) assert.Equal(t, uint32(10), atomic.LoadUint32(&total))
} }
func TestMap(t *testing.T) { func TestForEach(t *testing.T) {
tests := []struct { const tasks = 1000
mapper MapFunc
expect int
}{
{
mapper: func(item interface{}, writer Writer) {
v := item.(int)
writer.Write(v * v)
},
expect: 30,
},
{
mapper: func(item interface{}, writer Writer) {
v := item.(int)
if v%2 == 0 {
return
}
writer.Write(v * v)
},
expect: 10,
},
{
mapper: func(item interface{}, writer Writer) {
v := item.(int)
if v%2 == 0 {
panic(v)
}
writer.Write(v * v)
},
expect: 10,
},
}
for _, test := range tests { t.Run("all", func(t *testing.T) {
t.Run(stringx.Rand(), func(t *testing.T) { defer goleak.VerifyNone(t)
channel := Map(func(source chan<- interface{}) {
for i := 1; i < 5; i++ { var count uint32
ForEach(func(source chan<- interface{}) {
for i := 0; i < tasks; i++ {
source <- i
}
}, func(item interface{}) {
atomic.AddUint32(&count, 1)
}, WithWorkers(-1))
assert.Equal(t, tasks, int(count))
})
t.Run("odd", func(t *testing.T) {
defer goleak.VerifyNone(t)
var count uint32
ForEach(func(source chan<- interface{}) {
for i := 0; i < tasks; i++ {
source <- i
}
}, func(item interface{}) {
if item.(int)%2 == 0 {
atomic.AddUint32(&count, 1)
}
})
assert.Equal(t, tasks/2, int(count))
})
t.Run("all", func(t *testing.T) {
defer goleak.VerifyNone(t)
assert.PanicsWithValue(t, "foo", func() {
ForEach(func(source chan<- interface{}) {
for i := 0; i < tasks; i++ {
source <- i source <- i
} }
}, test.mapper, WithWorkers(-1)) }, func(item interface{}) {
panic("foo")
var result int })
for v := range channel {
result += v.(int)
}
assert.Equal(t, test.expect, result)
}) })
} })
}
func TestGeneratePanic(t *testing.T) {
defer goleak.VerifyNone(t)
t.Run("all", func(t *testing.T) {
assert.PanicsWithValue(t, "foo", func() {
ForEach(func(source chan<- interface{}) {
panic("foo")
}, func(item interface{}) {
})
})
})
}
func TestMapperPanic(t *testing.T) {
defer goleak.VerifyNone(t)
const tasks = 1000
var run int32
t.Run("all", func(t *testing.T) {
assert.PanicsWithValue(t, "foo", func() {
_, _ = MapReduce(func(source chan<- interface{}) {
for i := 0; i < tasks; i++ {
source <- i
}
}, func(item interface{}, writer Writer, cancel func(error)) {
atomic.AddInt32(&run, 1)
panic("foo")
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
})
})
assert.True(t, atomic.LoadInt32(&run) < tasks/2)
})
} }
func TestMapReduce(t *testing.T) { func TestMapReduce(t *testing.T) {
defer goleak.VerifyNone(t)
tests := []struct { tests := []struct {
name string
mapper MapperFunc mapper MapperFunc
reducer ReducerFunc reducer ReducerFunc
expectErr error expectErr error
expectValue interface{} expectValue interface{}
}{ }{
{ {
name: "simple",
expectErr: nil, expectErr: nil,
expectValue: 30, expectValue: 30,
}, },
{ {
name: "cancel with error",
mapper: func(item interface{}, writer Writer, cancel func(error)) { mapper: func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int) v := item.(int)
if v%3 == 0 { if v%3 == 0 {
@@ -149,6 +195,7 @@ func TestMapReduce(t *testing.T) {
expectErr: errDummy, expectErr: errDummy,
}, },
{ {
name: "cancel with nil",
mapper: func(item interface{}, writer Writer, cancel func(error)) { mapper: func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int) v := item.(int)
if v%3 == 0 { if v%3 == 0 {
@@ -160,6 +207,7 @@ func TestMapReduce(t *testing.T) {
expectValue: nil, expectValue: nil,
}, },
{ {
name: "cancel with more",
reducer: func(pipe <-chan interface{}, writer Writer, cancel func(error)) { reducer: func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
var result int var result int
for item := range pipe { for item := range pipe {
@@ -174,36 +222,74 @@ func TestMapReduce(t *testing.T) {
}, },
} }
for _, test := range tests { t.Run("MapReduce", func(t *testing.T) {
t.Run(stringx.Rand(), func(t *testing.T) { for _, test := range tests {
if test.mapper == nil { t.Run(test.name, func(t *testing.T) {
test.mapper = func(item interface{}, writer Writer, cancel func(error)) { if test.mapper == nil {
v := item.(int) test.mapper = func(item interface{}, writer Writer, cancel func(error)) {
writer.Write(v * v) v := item.(int)
} writer.Write(v * v)
}
if test.reducer == nil {
test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
var result int
for item := range pipe {
result += item.(int)
} }
writer.Write(result)
} }
} if test.reducer == nil {
value, err := MapReduce(func(source chan<- interface{}) { test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
for i := 1; i < 5; i++ { var result int
source <- i for item := range pipe {
result += item.(int)
}
writer.Write(result)
}
} }
}, test.mapper, test.reducer, WithWorkers(runtime.NumCPU())) value, err := MapReduce(func(source chan<- interface{}) {
for i := 1; i < 5; i++ {
source <- i
}
}, test.mapper, test.reducer, WithWorkers(runtime.NumCPU()))
assert.Equal(t, test.expectErr, err) assert.Equal(t, test.expectErr, err)
assert.Equal(t, test.expectValue, value) assert.Equal(t, test.expectValue, value)
}) })
} }
})
t.Run("MapReduce", func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.mapper == nil {
test.mapper = func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int)
writer.Write(v * v)
}
}
if test.reducer == nil {
test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
var result int
for item := range pipe {
result += item.(int)
}
writer.Write(result)
}
}
source := make(chan interface{})
go func() {
for i := 1; i < 5; i++ {
source <- i
}
close(source)
}()
value, err := MapReduceChan(source, test.mapper, test.reducer, WithWorkers(-1))
assert.Equal(t, test.expectErr, err)
assert.Equal(t, test.expectValue, value)
})
}
})
} }
func TestMapReduceWithReduerWriteMoreThanOnce(t *testing.T) { func TestMapReduceWithReduerWriteMoreThanOnce(t *testing.T) {
defer goleak.VerifyNone(t)
assert.Panics(t, func() { assert.Panics(t, func() {
MapReduce(func(source chan<- interface{}) { MapReduce(func(source chan<- interface{}) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@@ -220,18 +306,23 @@ func TestMapReduceWithReduerWriteMoreThanOnce(t *testing.T) {
} }
func TestMapReduceVoid(t *testing.T) { func TestMapReduceVoid(t *testing.T) {
defer goleak.VerifyNone(t)
var value uint32 var value uint32
tests := []struct { tests := []struct {
name string
mapper MapperFunc mapper MapperFunc
reducer VoidReducerFunc reducer VoidReducerFunc
expectValue uint32 expectValue uint32
expectErr error expectErr error
}{ }{
{ {
name: "simple",
expectValue: 30, expectValue: 30,
expectErr: nil, expectErr: nil,
}, },
{ {
name: "cancel with error",
mapper: func(item interface{}, writer Writer, cancel func(error)) { mapper: func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int) v := item.(int)
if v%3 == 0 { if v%3 == 0 {
@@ -242,6 +333,7 @@ func TestMapReduceVoid(t *testing.T) {
expectErr: errDummy, expectErr: errDummy,
}, },
{ {
name: "cancel with nil",
mapper: func(item interface{}, writer Writer, cancel func(error)) { mapper: func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int) v := item.(int)
if v%3 == 0 { if v%3 == 0 {
@@ -252,6 +344,7 @@ func TestMapReduceVoid(t *testing.T) {
expectErr: ErrCancelWithNil, expectErr: ErrCancelWithNil,
}, },
{ {
name: "cancel with more",
reducer: func(pipe <-chan interface{}, cancel func(error)) { reducer: func(pipe <-chan interface{}, cancel func(error)) {
for item := range pipe { for item := range pipe {
result := atomic.AddUint32(&value, uint32(item.(int))) result := atomic.AddUint32(&value, uint32(item.(int)))
@@ -265,7 +358,7 @@ func TestMapReduceVoid(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(stringx.Rand(), func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
atomic.StoreUint32(&value, 0) atomic.StoreUint32(&value, 0)
if test.mapper == nil { if test.mapper == nil {
@@ -296,6 +389,8 @@ func TestMapReduceVoid(t *testing.T) {
} }
func TestMapReduceVoidWithDelay(t *testing.T) { func TestMapReduceVoidWithDelay(t *testing.T) {
defer goleak.VerifyNone(t)
var result []int var result []int
err := MapReduceVoid(func(source chan<- interface{}) { err := MapReduceVoid(func(source chan<- interface{}) {
source <- 0 source <- 0
@@ -318,38 +413,64 @@ func TestMapReduceVoidWithDelay(t *testing.T) {
assert.Equal(t, 0, result[1]) assert.Equal(t, 0, result[1])
} }
func TestMapVoid(t *testing.T) { func TestMapReducePanic(t *testing.T) {
const tasks = 1000 defer goleak.VerifyNone(t)
var count uint32
MapVoid(func(source chan<- interface{}) {
for i := 0; i < tasks; i++ {
source <- i
}
}, func(item interface{}) {
atomic.AddUint32(&count, 1)
})
assert.Equal(t, tasks, int(count)) assert.Panics(t, func() {
_, _ = MapReduce(func(source chan<- interface{}) {
source <- 0
source <- 1
}, func(item interface{}, writer Writer, cancel func(error)) {
i := item.(int)
writer.Write(i)
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
for range pipe {
panic("panic")
}
})
})
} }
func TestMapReducePanic(t *testing.T) { func TestMapReducePanicOnce(t *testing.T) {
v, err := MapReduce(func(source chan<- interface{}) { defer goleak.VerifyNone(t)
source <- 0
source <- 1 assert.Panics(t, func() {
}, func(item interface{}, writer Writer, cancel func(error)) { _, _ = MapReduce(func(source chan<- interface{}) {
i := item.(int) for i := 0; i < 100; i++ {
writer.Write(i) source <- i
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) { }
for range pipe { }, func(item interface{}, writer Writer, cancel func(error)) {
panic("panic") i := item.(int)
} if i == 0 {
panic("foo")
}
writer.Write(i)
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
for range pipe {
panic("bar")
}
})
})
}
func TestMapReducePanicBothMapperAndReducer(t *testing.T) {
defer goleak.VerifyNone(t)
assert.Panics(t, func() {
_, _ = MapReduce(func(source chan<- interface{}) {
source <- 0
source <- 1
}, func(item interface{}, writer Writer, cancel func(error)) {
panic("foo")
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
panic("bar")
})
}) })
assert.Nil(t, v)
assert.NotNil(t, err)
assert.Equal(t, "panic", err.Error())
} }
func TestMapReduceVoidCancel(t *testing.T) { func TestMapReduceVoidCancel(t *testing.T) {
defer goleak.VerifyNone(t)
var result []int var result []int
err := MapReduceVoid(func(source chan<- interface{}) { err := MapReduceVoid(func(source chan<- interface{}) {
source <- 0 source <- 0
@@ -371,13 +492,15 @@ func TestMapReduceVoidCancel(t *testing.T) {
} }
func TestMapReduceVoidCancelWithRemains(t *testing.T) { func TestMapReduceVoidCancelWithRemains(t *testing.T) {
var done syncx.AtomicBool defer goleak.VerifyNone(t)
var done int32
var result []int var result []int
err := MapReduceVoid(func(source chan<- interface{}) { err := MapReduceVoid(func(source chan<- interface{}) {
for i := 0; i < defaultWorkers*2; i++ { for i := 0; i < defaultWorkers*2; i++ {
source <- i source <- i
} }
done.Set(true) atomic.AddInt32(&done, 1)
}, func(item interface{}, writer Writer, cancel func(error)) { }, func(item interface{}, writer Writer, cancel func(error)) {
i := item.(int) i := item.(int)
if i == defaultWorkers/2 { if i == defaultWorkers/2 {
@@ -392,10 +515,12 @@ func TestMapReduceVoidCancelWithRemains(t *testing.T) {
}) })
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, "anything", err.Error()) assert.Equal(t, "anything", err.Error())
assert.True(t, done.True()) assert.Equal(t, int32(1), done)
} }
func TestMapReduceWithoutReducerWrite(t *testing.T) { func TestMapReduceWithoutReducerWrite(t *testing.T) {
defer goleak.VerifyNone(t)
uids := []int{1, 2, 3} uids := []int{1, 2, 3}
res, err := MapReduce(func(source chan<- interface{}) { res, err := MapReduce(func(source chan<- interface{}) {
for _, uid := range uids { for _, uid := range uids {
@@ -412,33 +537,54 @@ func TestMapReduceWithoutReducerWrite(t *testing.T) {
} }
func TestMapReduceVoidPanicInReducer(t *testing.T) { func TestMapReduceVoidPanicInReducer(t *testing.T) {
defer goleak.VerifyNone(t)
const message = "foo" const message = "foo"
var done syncx.AtomicBool assert.Panics(t, func() {
err := MapReduceVoid(func(source chan<- interface{}) { var done int32
_ = MapReduceVoid(func(source chan<- interface{}) {
for i := 0; i < defaultWorkers*2; i++ {
source <- i
}
atomic.AddInt32(&done, 1)
}, func(item interface{}, writer Writer, cancel func(error)) {
i := item.(int)
writer.Write(i)
}, func(pipe <-chan interface{}, cancel func(error)) {
panic(message)
}, WithWorkers(1))
})
}
func TestForEachWithContext(t *testing.T) {
defer goleak.VerifyNone(t)
var done int32
ctx, cancel := context.WithCancel(context.Background())
ForEach(func(source chan<- interface{}) {
for i := 0; i < defaultWorkers*2; i++ { for i := 0; i < defaultWorkers*2; i++ {
source <- i source <- i
} }
done.Set(true) atomic.AddInt32(&done, 1)
}, func(item interface{}, writer Writer, cancel func(error)) { }, func(item interface{}) {
i := item.(int) i := item.(int)
writer.Write(i) if i == defaultWorkers/2 {
}, func(pipe <-chan interface{}, cancel func(error)) { cancel()
panic(message) }
}, WithWorkers(1)) }, WithContext(ctx))
assert.NotNil(t, err)
assert.Equal(t, message, err.Error())
assert.True(t, done.True())
} }
func TestMapReduceWithContext(t *testing.T) { func TestMapReduceWithContext(t *testing.T) {
var done syncx.AtomicBool defer goleak.VerifyNone(t)
var done int32
var result []int var result []int
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
err := MapReduceVoid(func(source chan<- interface{}) { err := MapReduceVoid(func(source chan<- interface{}) {
for i := 0; i < defaultWorkers*2; i++ { for i := 0; i < defaultWorkers*2; i++ {
source <- i source <- i
} }
done.Set(true) atomic.AddInt32(&done, 1)
}, func(item interface{}, writer Writer, c func(error)) { }, func(item interface{}, writer Writer, c func(error)) {
i := item.(int) i := item.(int)
if i == defaultWorkers/2 { if i == defaultWorkers/2 {
@@ -452,7 +598,7 @@ func TestMapReduceWithContext(t *testing.T) {
} }
}, WithContext(ctx)) }, WithContext(ctx))
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, ErrReduceNoOutput, err) assert.Equal(t, context.DeadlineExceeded, err)
} }
func BenchmarkMapReduce(b *testing.B) { func BenchmarkMapReduce(b *testing.B) {

View File

@@ -54,7 +54,7 @@ import (
"fmt" "fmt"
"log" "log"
"github.com/tal-tech/go-zero/core/mr" "github.com/zeromicro/go-zero/core/mr"
) )
func main() { func main() {

View File

@@ -55,7 +55,7 @@ import (
"fmt" "fmt"
"log" "log"
"github.com/tal-tech/go-zero/core/mr" "github.com/zeromicro/go-zero/core/mr"
) )
func main() { func main() {
@@ -87,4 +87,4 @@ More examples: [https://github.com/zeromicro/zero-examples/tree/main/mapreduce](
## Give a Star! ⭐ ## Give a Star! ⭐
If you like or are using this project to learn or start your solution, please give it a star. Thanks! If you like or are using this project to learn or start your solution, please give it a star. Thanks!

View File

@@ -1,3 +1,6 @@
//go:build linux || darwin
// +build linux darwin
package proc package proc
import ( import (

31
core/prof/runtime.go Normal file
View File

@@ -0,0 +1,31 @@
package prof
import (
"fmt"
"runtime"
"time"
)
const (
defaultInterval = time.Second * 5
mega = 1024 * 1024
)
// DisplayStats prints the goroutine, memory, GC stats with given interval, default to 5 seconds.
func DisplayStats(interval ...time.Duration) {
duration := defaultInterval
for _, val := range interval {
duration = val
}
go func() {
ticker := time.NewTicker(duration)
defer ticker.Stop()
for range ticker.C {
var m runtime.MemStats
runtime.ReadMemStats(&m)
fmt.Printf("Goroutines: %d, Alloc: %vm, TotalAlloc: %vm, Sys: %vm, NumGC: %v\n",
runtime.NumGoroutine(), m.Alloc/mega, m.TotalAlloc/mega, m.Sys/mega, m.NumGC)
}
}()
}

View File

@@ -1,78 +1,129 @@
package internal package internal
import ( import (
"bufio"
"fmt" "fmt"
"os" "os"
"path" "path"
"strconv" "strconv"
"strings" "strings"
"sync"
"time"
"github.com/zeromicro/go-zero/core/iox" "github.com/zeromicro/go-zero/core/iox"
"github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/lang"
"golang.org/x/sys/unix"
) )
const cgroupDir = "/sys/fs/cgroup" const (
cgroupDir = "/sys/fs/cgroup"
cpuStatFile = cgroupDir + "/cpu.stat"
cpusetFile = cgroupDir + "/cpuset.cpus.effective"
)
type cgroup struct { var (
isUnifiedOnce sync.Once
isUnified bool
inUserNS bool
nsOnce sync.Once
)
type cgroup interface {
cpuQuotaUs() (int64, error)
cpuPeriodUs() (uint64, error)
cpus() ([]uint64, error)
usageAllCpus() (uint64, error)
}
func currentCgroup() (cgroup, error) {
if isCgroup2UnifiedMode() {
return currentCgroupV2()
}
return currentCgroupV1()
}
type cgroupV1 struct {
cgroups map[string]string cgroups map[string]string
} }
func (c *cgroup) acctUsageAllCpus() (uint64, error) { func (c *cgroupV1) cpuQuotaUs() (int64, error) {
data, err := iox.ReadText(path.Join(c.cgroups["cpuacct"], "cpuacct.usage"))
if err != nil {
return 0, err
}
return parseUint(string(data))
}
func (c *cgroup) acctUsagePerCpu() ([]uint64, error) {
data, err := iox.ReadText(path.Join(c.cgroups["cpuacct"], "cpuacct.usage_percpu"))
if err != nil {
return nil, err
}
var usage []uint64
for _, v := range strings.Fields(string(data)) {
u, err := parseUint(v)
if err != nil {
return nil, err
}
usage = append(usage, u)
}
return usage, nil
}
func (c *cgroup) cpuQuotaUs() (int64, error) {
data, err := iox.ReadText(path.Join(c.cgroups["cpu"], "cpu.cfs_quota_us")) data, err := iox.ReadText(path.Join(c.cgroups["cpu"], "cpu.cfs_quota_us"))
if err != nil { if err != nil {
return 0, err return 0, err
} }
return strconv.ParseInt(string(data), 10, 64) return strconv.ParseInt(data, 10, 64)
} }
func (c *cgroup) cpuPeriodUs() (uint64, error) { func (c *cgroupV1) cpuPeriodUs() (uint64, error) {
data, err := iox.ReadText(path.Join(c.cgroups["cpu"], "cpu.cfs_period_us")) data, err := iox.ReadText(path.Join(c.cgroups["cpu"], "cpu.cfs_period_us"))
if err != nil { if err != nil {
return 0, err return 0, err
} }
return parseUint(string(data)) return parseUint(data)
} }
func (c *cgroup) cpus() ([]uint64, error) { func (c *cgroupV1) cpus() ([]uint64, error) {
data, err := iox.ReadText(path.Join(c.cgroups["cpuset"], "cpuset.cpus")) data, err := iox.ReadText(path.Join(c.cgroups["cpuset"], "cpuset.cpus"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return parseUints(string(data)) return parseUints(data)
} }
func currentCgroup() (*cgroup, error) { func (c *cgroupV1) usageAllCpus() (uint64, error) {
data, err := iox.ReadText(path.Join(c.cgroups["cpuacct"], "cpuacct.usage"))
if err != nil {
return 0, err
}
return parseUint(data)
}
type cgroupV2 struct {
cgroups map[string]string
}
func (c *cgroupV2) cpuQuotaUs() (int64, error) {
data, err := iox.ReadText(path.Join(cgroupDir, "cpu.cfs_quota_us"))
if err != nil {
return 0, err
}
return strconv.ParseInt(data, 10, 64)
}
func (c *cgroupV2) cpuPeriodUs() (uint64, error) {
data, err := iox.ReadText(path.Join(cgroupDir, "cpu.cfs_period_us"))
if err != nil {
return 0, err
}
return parseUint(data)
}
func (c *cgroupV2) cpus() ([]uint64, error) {
data, err := iox.ReadText(cpusetFile)
if err != nil {
return nil, err
}
return parseUints(data)
}
func (c *cgroupV2) usageAllCpus() (uint64, error) {
usec, err := parseUint(c.cgroups["usage_usec"])
if err != nil {
return 0, err
}
return usec * uint64(time.Microsecond), nil
}
func currentCgroupV1() (cgroup, error) {
cgroupFile := fmt.Sprintf("/proc/%d/cgroup", os.Getpid()) cgroupFile := fmt.Sprintf("/proc/%d/cgroup", os.Getpid())
lines, err := iox.ReadTextLines(cgroupFile, iox.WithoutBlank()) lines, err := iox.ReadTextLines(cgroupFile, iox.WithoutBlank())
if err != nil { if err != nil {
@@ -100,11 +151,51 @@ func currentCgroup() (*cgroup, error) {
} }
} }
return &cgroup{ return &cgroupV1{
cgroups: cgroups, cgroups: cgroups,
}, nil }, nil
} }
func currentCgroupV2() (cgroup, error) {
lines, err := iox.ReadTextLines(cpuStatFile, iox.WithoutBlank())
if err != nil {
return nil, err
}
cgroups := make(map[string]string)
for _, line := range lines {
cols := strings.Fields(line)
if len(cols) != 2 {
return nil, fmt.Errorf("invalid cgroupV2 line: %s", line)
}
cgroups[cols[0]] = cols[1]
}
return &cgroupV2{
cgroups: cgroups,
}, nil
}
// isCgroup2UnifiedMode returns whether we are running in cgroup v2 unified mode.
func isCgroup2UnifiedMode() bool {
isUnifiedOnce.Do(func() {
var st unix.Statfs_t
err := unix.Statfs(cgroupDir, &st)
if err != nil {
if os.IsNotExist(err) && runningInUserNS() {
// ignore the "not found" error if running in userns
isUnified = false
return
}
panic(fmt.Sprintf("cannot statfs cgroup root: %s", err))
}
isUnified = st.Type == unix.CGROUP2_SUPER_MAGIC
})
return isUnified
}
func parseUint(s string) (uint64, error) { func parseUint(s string) (uint64, error) {
v, err := strconv.ParseInt(s, 10, 64) v, err := strconv.ParseInt(s, 10, 64)
if err != nil { if err != nil {
@@ -166,3 +257,36 @@ func parseUints(val string) ([]uint64, error) {
return sets, nil return sets, nil
} }
// runningInUserNS detects whether we are currently running in a user namespace.
func runningInUserNS() bool {
nsOnce.Do(func() {
file, err := os.Open("/proc/self/uid_map")
if err != nil {
// This kernel-provided file only exists if user namespaces are supported
return
}
defer file.Close()
buf := bufio.NewReader(file)
l, _, err := buf.ReadLine()
if err != nil {
return
}
line := string(l)
var a, b, c int64
fmt.Sscanf(line, "%d %d %d", &a, &b, &c)
/*
* We assume we are in the initial user namespace if we have a full
* range - 4294967295 uids starting at uid 0.
*/
if a == 0 && b == 0 && c == 4294967295 {
return
}
inUserNS = true
})
return inUserNS
}

View File

@@ -24,7 +24,7 @@ var (
// if /proc not present, ignore the cpu calculation, like wsl linux // if /proc not present, ignore the cpu calculation, like wsl linux
func init() { func init() {
cpus, err := perCpuUsage() cpus, err := cpuSets()
if err != nil { if err != nil {
logx.Error(err) logx.Error(err)
return return
@@ -117,15 +117,6 @@ func cpuSets() ([]uint64, error) {
return cg.cpus() return cg.cpus()
} }
func perCpuUsage() ([]uint64, error) {
cg, err := currentCgroup()
if err != nil {
return nil, err
}
return cg.acctUsagePerCpu()
}
func systemCpuUsage() (uint64, error) { func systemCpuUsage() (uint64, error) {
lines, err := iox.ReadTextLines("/proc/stat", iox.WithoutBlank()) lines, err := iox.ReadTextLines("/proc/stat", iox.WithoutBlank())
if err != nil { if err != nil {
@@ -157,10 +148,10 @@ func systemCpuUsage() (uint64, error) {
} }
func totalCpuUsage() (usage uint64, err error) { func totalCpuUsage() (usage uint64, err error) {
var cg *cgroup var cg cgroup
if cg, err = currentCgroup(); err != nil { if cg, err = currentCgroup(); err != nil {
return return
} }
return cg.acctUsageAllCpus() return cg.usageAllCpus()
} }

View File

@@ -10,7 +10,10 @@ import (
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )
const httpTimeout = time.Second * 5 const (
httpTimeout = time.Second * 5
jsonContentType = "application/json; charset=utf-8"
)
// ErrWriteFailed is an error that indicates failed to submit a StatReport. // ErrWriteFailed is an error that indicates failed to submit a StatReport.
var ErrWriteFailed = errors.New("submit failed") var ErrWriteFailed = errors.New("submit failed")
@@ -36,7 +39,7 @@ func (rw *RemoteWriter) Write(report *StatReport) error {
client := &http.Client{ client := &http.Client{
Timeout: httpTimeout, Timeout: httpTimeout,
} }
resp, err := client.Post(rw.endpoint, "application/json", bytes.NewBuffer(bs)) resp, err := client.Post(rw.endpoint, jsonContentType, bytes.NewReader(bs))
if err != nil { if err != nil {
return err return err
} }

View File

@@ -30,18 +30,32 @@ func RawFieldNames(in interface{}, postgresSql ...bool) []string {
for i := 0; i < v.NumField(); i++ { for i := 0; i < v.NumField(); i++ {
// gets us a StructField // gets us a StructField
fi := typ.Field(i) fi := typ.Field(i)
if tagv := fi.Tag.Get(dbTag); tagv != "" { tagv := fi.Tag.Get(dbTag)
if pg { switch tagv {
out = append(out, tagv) case "-":
} else { continue
out = append(out, fmt.Sprintf("`%s`", tagv)) case "":
}
} else {
if pg { if pg {
out = append(out, fi.Name) out = append(out, fi.Name)
} else { } else {
out = append(out, fmt.Sprintf("`%s`", fi.Name)) out = append(out, fmt.Sprintf("`%s`", fi.Name))
} }
default:
// get tag name with the tag opton, e.g.:
// `db:"id"`
// `db:"id,type=char,length=16"`
// `db:",type=char,length=16"`
if strings.Contains(tagv, ",") {
tagv = strings.TrimSpace(strings.Split(tagv, ",")[0])
}
if len(tagv) == 0 {
tagv = fi.Name
}
if pg {
out = append(out, tagv)
} else {
out = append(out, fmt.Sprintf("`%s`", tagv))
}
} }
} }

View File

@@ -22,3 +22,20 @@ func TestFieldNames(t *testing.T) {
assert.Equal(t, expected, out) assert.Equal(t, expected, out)
}) })
} }
type mockedUserWithOptions struct {
ID string `db:"id" json:"id,omitempty"`
UserName string `db:"user_name,type=varchar,length=255" json:"userName,omitempty"`
Sex int `db:"sex" json:"sex,omitempty"`
UUID string `db:",type=varchar,length=16" uuid:"uuid,omitempty"`
Age int `db:"age" json:"age"`
}
func TestFieldNamesWithTagOptions(t *testing.T) {
t.Run("new", func(t *testing.T) {
var u mockedUserWithOptions
out := RawFieldNames(&u)
expected := []string{"`id`", "`user_name`", "`sex`", "`UUID`", "`age`"}
assert.Equal(t, expected, out)
})
}

View File

@@ -1,6 +1,8 @@
package cache package cache
import ( import (
"context"
"errors"
"fmt" "fmt"
"log" "log"
"time" "time"
@@ -13,13 +15,37 @@ import (
type ( type (
// Cache interface is used to define the cache implementation. // Cache interface is used to define the cache implementation.
Cache interface { Cache interface {
// Del deletes cached values with keys.
Del(keys ...string) error Del(keys ...string) error
Get(key string, v interface{}) error // DelCtx deletes cached values with keys.
DelCtx(ctx context.Context, keys ...string) error
// Get gets the cache with key and fills into v.
Get(key string, val interface{}) error
// GetCtx gets the cache with key and fills into v.
GetCtx(ctx context.Context, key string, val interface{}) error
// IsNotFound checks if the given error is the defined errNotFound.
IsNotFound(err error) bool IsNotFound(err error) bool
Set(key string, v interface{}) error // Set sets the cache with key and v, using c.expiry.
SetWithExpire(key string, v interface{}, expire time.Duration) error Set(key string, val interface{}) error
Take(v interface{}, key string, query func(v interface{}) error) error // SetCtx sets the cache with key and v, using c.expiry.
TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error SetCtx(ctx context.Context, key string, val interface{}) error
// SetWithExpire sets the cache with key and v, using given expire.
SetWithExpire(key string, val interface{}, expire time.Duration) error
// SetWithExpireCtx sets the cache with key and v, using given expire.
SetWithExpireCtx(ctx context.Context, key string, val interface{}, expire time.Duration) error
// Take takes the result from cache first, if not found,
// query from DB and set cache using c.expiry, then return the result.
Take(val interface{}, key string, query func(val interface{}) error) error
// TakeCtx takes the result from cache first, if not found,
// query from DB and set cache using c.expiry, then return the result.
TakeCtx(ctx context.Context, val interface{}, key string, query func(val interface{}) error) error
// TakeWithExpire takes the result from cache first, if not found,
// query from DB and set cache using given expire, then return the result.
TakeWithExpire(val interface{}, key string, query func(val interface{}, expire time.Duration) error) error
// TakeWithExpireCtx takes the result from cache first, if not found,
// query from DB and set cache using given expire, then return the result.
TakeWithExpireCtx(ctx context.Context, val interface{}, key string,
query func(val interface{}, expire time.Duration) error) error
} }
cacheCluster struct { cacheCluster struct {
@@ -51,7 +77,13 @@ func New(c ClusterConf, barrier syncx.SingleFlight, st *Stat, errNotFound error,
} }
} }
// Del deletes cached values with keys.
func (cc cacheCluster) Del(keys ...string) error { func (cc cacheCluster) Del(keys ...string) error {
return cc.DelCtx(context.Background(), keys...)
}
// DelCtx deletes cached values with keys.
func (cc cacheCluster) DelCtx(ctx context.Context, keys ...string) error {
switch len(keys) { switch len(keys) {
case 0: case 0:
return nil return nil
@@ -62,7 +94,7 @@ func (cc cacheCluster) Del(keys ...string) error {
return cc.errNotFound return cc.errNotFound
} }
return c.(Cache).Del(key) return c.(Cache).DelCtx(ctx, key)
default: default:
var be errorx.BatchError var be errorx.BatchError
nodes := make(map[interface{}][]string) nodes := make(map[interface{}][]string)
@@ -76,7 +108,7 @@ func (cc cacheCluster) Del(keys ...string) error {
nodes[c] = append(nodes[c], key) nodes[c] = append(nodes[c], key)
} }
for c, ks := range nodes { for c, ks := range nodes {
if err := c.(Cache).Del(ks...); err != nil { if err := c.(Cache).DelCtx(ctx, ks...); err != nil {
be.Add(err) be.Add(err)
} }
} }
@@ -85,52 +117,86 @@ func (cc cacheCluster) Del(keys ...string) error {
} }
} }
func (cc cacheCluster) Get(key string, v interface{}) error { // Get gets the cache with key and fills into v.
func (cc cacheCluster) Get(key string, val interface{}) error {
return cc.GetCtx(context.Background(), key, val)
}
// GetCtx gets the cache with key and fills into v.
func (cc cacheCluster) GetCtx(ctx context.Context, key string, val interface{}) error {
c, ok := cc.dispatcher.Get(key) c, ok := cc.dispatcher.Get(key)
if !ok { if !ok {
return cc.errNotFound return cc.errNotFound
} }
return c.(Cache).Get(key, v) return c.(Cache).GetCtx(ctx, key, val)
} }
// IsNotFound checks if the given error is the defined errNotFound.
func (cc cacheCluster) IsNotFound(err error) bool { func (cc cacheCluster) IsNotFound(err error) bool {
return err == cc.errNotFound return errors.Is(err, cc.errNotFound)
} }
func (cc cacheCluster) Set(key string, v interface{}) error { // Set sets the cache with key and v, using c.expiry.
func (cc cacheCluster) Set(key string, val interface{}) error {
return cc.SetCtx(context.Background(), key, val)
}
// SetCtx sets the cache with key and v, using c.expiry.
func (cc cacheCluster) SetCtx(ctx context.Context, key string, val interface{}) error {
c, ok := cc.dispatcher.Get(key) c, ok := cc.dispatcher.Get(key)
if !ok { if !ok {
return cc.errNotFound return cc.errNotFound
} }
return c.(Cache).Set(key, v) return c.(Cache).SetCtx(ctx, key, val)
} }
func (cc cacheCluster) SetWithExpire(key string, v interface{}, expire time.Duration) error { // SetWithExpire sets the cache with key and v, using given expire.
func (cc cacheCluster) SetWithExpire(key string, val interface{}, expire time.Duration) error {
return cc.SetWithExpireCtx(context.Background(), key, val, expire)
}
// SetWithExpireCtx sets the cache with key and v, using given expire.
func (cc cacheCluster) SetWithExpireCtx(ctx context.Context, key string, val interface{}, expire time.Duration) error {
c, ok := cc.dispatcher.Get(key) c, ok := cc.dispatcher.Get(key)
if !ok { if !ok {
return cc.errNotFound return cc.errNotFound
} }
return c.(Cache).SetWithExpire(key, v, expire) return c.(Cache).SetWithExpireCtx(ctx, key, val, expire)
} }
func (cc cacheCluster) Take(v interface{}, key string, query func(v interface{}) error) error { // Take takes the result from cache first, if not found,
// query from DB and set cache using c.expiry, then return the result.
func (cc cacheCluster) Take(val interface{}, key string, query func(val interface{}) error) error {
return cc.TakeCtx(context.Background(), val, key, query)
}
// TakeCtx takes the result from cache first, if not found,
// query from DB and set cache using c.expiry, then return the result.
func (cc cacheCluster) TakeCtx(ctx context.Context, val interface{}, key string, query func(val interface{}) error) error {
c, ok := cc.dispatcher.Get(key) c, ok := cc.dispatcher.Get(key)
if !ok { if !ok {
return cc.errNotFound return cc.errNotFound
} }
return c.(Cache).Take(v, key, query) return c.(Cache).TakeCtx(ctx, val, key, query)
} }
func (cc cacheCluster) TakeWithExpire(v interface{}, key string, // TakeWithExpire takes the result from cache first, if not found,
query func(v interface{}, expire time.Duration) error) error { // query from DB and set cache using given expire, then return the result.
func (cc cacheCluster) TakeWithExpire(val interface{}, key string, query func(val interface{}, expire time.Duration) error) error {
return cc.TakeWithExpireCtx(context.Background(), val, key, query)
}
// TakeWithExpireCtx takes the result from cache first, if not found,
// query from DB and set cache using given expire, then return the result.
func (cc cacheCluster) TakeWithExpireCtx(ctx context.Context, val interface{}, key string, query func(val interface{}, expire time.Duration) error) error {
c, ok := cc.dispatcher.Get(key) c, ok := cc.dispatcher.Get(key)
if !ok { if !ok {
return cc.errNotFound return cc.errNotFound
} }
return c.(Cache).TakeWithExpire(v, key, query) return c.(Cache).TakeWithExpireCtx(ctx, val, key, query)
} }

View File

@@ -1,7 +1,9 @@
package cache package cache
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"math" "math"
"strconv" "strconv"
@@ -16,12 +18,18 @@ import (
"github.com/zeromicro/go-zero/core/syncx" "github.com/zeromicro/go-zero/core/syncx"
) )
var _ Cache = (*mockedNode)(nil)
type mockedNode struct { type mockedNode struct {
vals map[string][]byte vals map[string][]byte
errNotFound error errNotFound error
} }
func (mc *mockedNode) Del(keys ...string) error { func (mc *mockedNode) Del(keys ...string) error {
return mc.DelCtx(context.Background(), keys...)
}
func (mc *mockedNode) DelCtx(_ context.Context, keys ...string) error {
var be errorx.BatchError var be errorx.BatchError
for _, key := range keys { for _, key := range keys {
@@ -35,21 +43,29 @@ func (mc *mockedNode) Del(keys ...string) error {
return be.Err() return be.Err()
} }
func (mc *mockedNode) Get(key string, v interface{}) error { func (mc *mockedNode) Get(key string, val interface{}) error {
return mc.GetCtx(context.Background(), key, val)
}
func (mc *mockedNode) GetCtx(ctx context.Context, key string, val interface{}) error {
bs, ok := mc.vals[key] bs, ok := mc.vals[key]
if ok { if ok {
return json.Unmarshal(bs, v) return json.Unmarshal(bs, val)
} }
return mc.errNotFound return mc.errNotFound
} }
func (mc *mockedNode) IsNotFound(err error) bool { func (mc *mockedNode) IsNotFound(err error) bool {
return err == mc.errNotFound return errors.Is(err, mc.errNotFound)
} }
func (mc *mockedNode) Set(key string, v interface{}) error { func (mc *mockedNode) Set(key string, val interface{}) error {
data, err := json.Marshal(v) return mc.SetCtx(context.Background(), key, val)
}
func (mc *mockedNode) SetCtx(ctx context.Context, key string, val interface{}) error {
data, err := json.Marshal(val)
if err != nil { if err != nil {
return err return err
} }
@@ -58,25 +74,37 @@ func (mc *mockedNode) Set(key string, v interface{}) error {
return nil return nil
} }
func (mc *mockedNode) SetWithExpire(key string, v interface{}, expire time.Duration) error { func (mc *mockedNode) SetWithExpire(key string, val interface{}, expire time.Duration) error {
return mc.Set(key, v) return mc.SetWithExpireCtx(context.Background(), key, val, expire)
} }
func (mc *mockedNode) Take(v interface{}, key string, query func(v interface{}) error) error { func (mc *mockedNode) SetWithExpireCtx(ctx context.Context, key string, val interface{}, expire time.Duration) error {
return mc.Set(key, val)
}
func (mc *mockedNode) Take(val interface{}, key string, query func(val interface{}) error) error {
return mc.TakeCtx(context.Background(), val, key, query)
}
func (mc *mockedNode) TakeCtx(ctx context.Context, val interface{}, key string, query func(val interface{}) error) error {
if _, ok := mc.vals[key]; ok { if _, ok := mc.vals[key]; ok {
return mc.Get(key, v) return mc.GetCtx(ctx, key, val)
} }
if err := query(v); err != nil { if err := query(val); err != nil {
return err return err
} }
return mc.Set(key, v) return mc.SetCtx(ctx, key, val)
} }
func (mc *mockedNode) TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error { func (mc *mockedNode) TakeWithExpire(val interface{}, key string, query func(val interface{}, expire time.Duration) error) error {
return mc.Take(v, key, func(v interface{}) error { return mc.TakeWithExpireCtx(context.Background(), val, key, query)
return query(v, 0) }
func (mc *mockedNode) TakeWithExpireCtx(ctx context.Context, val interface{}, key string, query func(val interface{}, expire time.Duration) error) error {
return mc.Take(val, key, func(val interface{}) error {
return query(val, 0)
}) })
} }
@@ -113,18 +141,18 @@ func TestCache_SetDel(t *testing.T) {
} }
} }
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
var v int var val int
assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &v)) assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &val))
assert.Equal(t, i, v) assert.Equal(t, i, val)
} }
assert.Nil(t, c.Del()) assert.Nil(t, c.Del())
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
assert.Nil(t, c.Del(fmt.Sprintf("key/%d", i))) assert.Nil(t, c.Del(fmt.Sprintf("key/%d", i)))
} }
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
var v int var val int
assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &v))) assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &val)))
assert.Equal(t, 0, v) assert.Equal(t, 0, val)
} }
} }
@@ -151,18 +179,18 @@ func TestCache_OneNode(t *testing.T) {
} }
} }
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
var v int var val int
assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &v)) assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &val))
assert.Equal(t, i, v) assert.Equal(t, i, val)
} }
assert.Nil(t, c.Del()) assert.Nil(t, c.Del())
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
assert.Nil(t, c.Del(fmt.Sprintf("key/%d", i))) assert.Nil(t, c.Del(fmt.Sprintf("key/%d", i)))
} }
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
var v int var val int
assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &v))) assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &val)))
assert.Equal(t, 0, v) assert.Equal(t, 0, val)
} }
} }
@@ -202,9 +230,9 @@ func TestCache_Balance(t *testing.T) {
assert.True(t, entropy > .95, fmt.Sprintf("entropy should be greater than 0.95, but got %.2f", entropy)) assert.True(t, entropy > .95, fmt.Sprintf("entropy should be greater than 0.95, but got %.2f", entropy))
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
var v int var val int
assert.Nil(t, c.Get(strconv.Itoa(i), &v)) assert.Nil(t, c.Get(strconv.Itoa(i), &val))
assert.Equal(t, i, v) assert.Equal(t, i, val)
} }
for i := 0; i < total/10; i++ { for i := 0; i < total/10; i++ {
@@ -216,14 +244,14 @@ func TestCache_Balance(t *testing.T) {
for i := 0; i < total/10; i++ { for i := 0; i < total/10; i++ {
var val int var val int
if i%2 == 0 { if i%2 == 0 {
assert.Nil(t, c.Take(&val, strconv.Itoa(i*10), func(v interface{}) error { assert.Nil(t, c.Take(&val, strconv.Itoa(i*10), func(val interface{}) error {
*v.(*int) = i *val.(*int) = i
count++ count++
return nil return nil
})) }))
} else { } else {
assert.Nil(t, c.TakeWithExpire(&val, strconv.Itoa(i*10), func(v interface{}, expire time.Duration) error { assert.Nil(t, c.TakeWithExpire(&val, strconv.Itoa(i*10), func(val interface{}, expire time.Duration) error {
*v.(*int) = i *val.(*int) = i
count++ count++
return nil return nil
})) }))
@@ -244,10 +272,10 @@ func TestCacheNoNode(t *testing.T) {
assert.NotNil(t, c.Get("foo", nil)) assert.NotNil(t, c.Get("foo", nil))
assert.NotNil(t, c.Set("foo", nil)) assert.NotNil(t, c.Set("foo", nil))
assert.NotNil(t, c.SetWithExpire("foo", nil, time.Second)) assert.NotNil(t, c.SetWithExpire("foo", nil, time.Second))
assert.NotNil(t, c.Take(nil, "foo", func(v interface{}) error { assert.NotNil(t, c.Take(nil, "foo", func(val interface{}) error {
return nil return nil
})) }))
assert.NotNil(t, c.TakeWithExpire(nil, "foo", func(v interface{}, duration time.Duration) error { assert.NotNil(t, c.TakeWithExpire(nil, "foo", func(val interface{}, duration time.Duration) error {
return nil return nil
})) }))
} }
@@ -255,8 +283,8 @@ func TestCacheNoNode(t *testing.T) {
func calcEntropy(m map[int]int, total int) float64 { func calcEntropy(m map[int]int, total int) float64 {
var entropy float64 var entropy float64
for _, v := range m { for _, val := range m {
proba := float64(v) / float64(total) proba := float64(val) / float64(total)
entropy -= proba * math.Log2(proba) entropy -= proba * math.Log2(proba)
} }

View File

@@ -1,6 +1,7 @@
package cache package cache
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"math/rand" "math/rand"
@@ -61,30 +62,39 @@ func NewNode(rds *redis.Redis, barrier syncx.SingleFlight, st *Stat,
// Del deletes cached values with keys. // Del deletes cached values with keys.
func (c cacheNode) Del(keys ...string) error { func (c cacheNode) Del(keys ...string) error {
return c.DelCtx(context.Background(), keys...)
}
// DelCtx deletes cached values with keys.
func (c cacheNode) DelCtx(ctx context.Context, keys ...string) error {
if len(keys) == 0 { if len(keys) == 0 {
return nil return nil
} }
logger := logx.WithContext(ctx)
if len(keys) > 1 && c.rds.Type == redis.ClusterType { if len(keys) > 1 && c.rds.Type == redis.ClusterType {
for _, key := range keys { for _, key := range keys {
if _, err := c.rds.Del(key); err != nil { if _, err := c.rds.DelCtx(ctx, key); err != nil {
logx.Errorf("failed to clear cache with key: %q, error: %v", key, err) logger.Errorf("failed to clear cache with key: %q, error: %v", key, err)
c.asyncRetryDelCache(key) c.asyncRetryDelCache(key)
} }
} }
} else { } else if _, err := c.rds.DelCtx(ctx, keys...); err != nil {
if _, err := c.rds.Del(keys...); err != nil { logger.Errorf("failed to clear cache with keys: %q, error: %v", formatKeys(keys), err)
logx.Errorf("failed to clear cache with keys: %q, error: %v", formatKeys(keys), err) c.asyncRetryDelCache(keys...)
c.asyncRetryDelCache(keys...)
}
} }
return nil return nil
} }
// Get gets the cache with key and fills into v. // Get gets the cache with key and fills into v.
func (c cacheNode) Get(key string, v interface{}) error { func (c cacheNode) Get(key string, val interface{}) error {
err := c.doGetCache(key, v) return c.GetCtx(context.Background(), key, val)
}
// GetCtx gets the cache with key and fills into v.
func (c cacheNode) GetCtx(ctx context.Context, key string, val interface{}) error {
err := c.doGetCache(ctx, key, val)
if err == errPlaceholder { if err == errPlaceholder {
return c.errNotFound return c.errNotFound
} }
@@ -94,22 +104,32 @@ func (c cacheNode) Get(key string, v interface{}) error {
// IsNotFound checks if the given error is the defined errNotFound. // IsNotFound checks if the given error is the defined errNotFound.
func (c cacheNode) IsNotFound(err error) bool { func (c cacheNode) IsNotFound(err error) bool {
return err == c.errNotFound return errors.Is(err, c.errNotFound)
} }
// Set sets the cache with key and v, using c.expiry. // Set sets the cache with key and v, using c.expiry.
func (c cacheNode) Set(key string, v interface{}) error { func (c cacheNode) Set(key string, val interface{}) error {
return c.SetWithExpire(key, v, c.aroundDuration(c.expiry)) return c.SetCtx(context.Background(), key, val)
}
// SetCtx sets the cache with key and v, using c.expiry.
func (c cacheNode) SetCtx(ctx context.Context, key string, val interface{}) error {
return c.SetWithExpireCtx(ctx, key, val, c.aroundDuration(c.expiry))
} }
// SetWithExpire sets the cache with key and v, using given expire. // SetWithExpire sets the cache with key and v, using given expire.
func (c cacheNode) SetWithExpire(key string, v interface{}, expire time.Duration) error { func (c cacheNode) SetWithExpire(key string, val interface{}, expire time.Duration) error {
data, err := jsonx.Marshal(v) return c.SetWithExpireCtx(context.Background(), key, val, expire)
}
// SetWithExpireCtx sets the cache with key and v, using given expire.
func (c cacheNode) SetWithExpireCtx(ctx context.Context, key string, val interface{}, expire time.Duration) error {
data, err := jsonx.Marshal(val)
if err != nil { if err != nil {
return err return err
} }
return c.rds.Setex(key, string(data), int(expire.Seconds())) return c.rds.SetexCtx(ctx, key, string(data), int(expire.Seconds()))
} }
// String returns a string that represents the cacheNode. // String returns a string that represents the cacheNode.
@@ -119,21 +139,32 @@ func (c cacheNode) String() string {
// Take takes the result from cache first, if not found, // Take takes the result from cache first, if not found,
// query from DB and set cache using c.expiry, then return the result. // query from DB and set cache using c.expiry, then return the result.
func (c cacheNode) Take(v interface{}, key string, query func(v interface{}) error) error { func (c cacheNode) Take(val interface{}, key string, query func(val interface{}) error) error {
return c.doTake(v, key, query, func(v interface{}) error { return c.TakeCtx(context.Background(), val, key, query)
return c.Set(key, v) }
// TakeCtx takes the result from cache first, if not found,
// query from DB and set cache using c.expiry, then return the result.
func (c cacheNode) TakeCtx(ctx context.Context, val interface{}, key string, query func(val interface{}) error) error {
return c.doTake(ctx, val, key, query, func(v interface{}) error {
return c.SetCtx(ctx, key, v)
}) })
} }
// TakeWithExpire takes the result from cache first, if not found, // TakeWithExpire takes the result from cache first, if not found,
// query from DB and set cache using given expire, then return the result. // query from DB and set cache using given expire, then return the result.
func (c cacheNode) TakeWithExpire(v interface{}, key string, query func(v interface{}, func (c cacheNode) TakeWithExpire(val interface{}, key string, query func(val interface{}, expire time.Duration) error) error {
expire time.Duration) error) error { return c.TakeWithExpireCtx(context.Background(), val, key, query)
}
// TakeWithExpireCtx takes the result from cache first, if not found,
// query from DB and set cache using given expire, then return the result.
func (c cacheNode) TakeWithExpireCtx(ctx context.Context, val interface{}, key string, query func(val interface{}, expire time.Duration) error) error {
expire := c.aroundDuration(c.expiry) expire := c.aroundDuration(c.expiry)
return c.doTake(v, key, func(v interface{}) error { return c.doTake(ctx, val, key, func(v interface{}) error {
return query(v, expire) return query(v, expire)
}, func(v interface{}) error { }, func(v interface{}) error {
return c.SetWithExpire(key, v, expire) return c.SetWithExpireCtx(ctx, key, v, expire)
}) })
} }
@@ -148,9 +179,9 @@ func (c cacheNode) asyncRetryDelCache(keys ...string) {
}, keys...) }, keys...)
} }
func (c cacheNode) doGetCache(key string, v interface{}) error { func (c cacheNode) doGetCache(ctx context.Context, key string, v interface{}) error {
c.stat.IncrementTotal() c.stat.IncrementTotal()
data, err := c.rds.Get(key) data, err := c.rds.GetCtx(ctx, key)
if err != nil { if err != nil {
c.stat.IncrementMiss() c.stat.IncrementMiss()
return err return err
@@ -166,13 +197,14 @@ func (c cacheNode) doGetCache(key string, v interface{}) error {
return errPlaceholder return errPlaceholder
} }
return c.processCache(key, data, v) return c.processCache(ctx, key, data, v)
} }
func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) error, func (c cacheNode) doTake(ctx context.Context, v interface{}, key string,
cacheVal func(v interface{}) error) error { query func(v interface{}) error, cacheVal func(v interface{}) error) error {
logger := logx.WithContext(ctx)
val, fresh, err := c.barrier.DoEx(key, func() (interface{}, error) { val, fresh, err := c.barrier.DoEx(key, func() (interface{}, error) {
if err := c.doGetCache(key, v); err != nil { if err := c.doGetCache(ctx, key, v); err != nil {
if err == errPlaceholder { if err == errPlaceholder {
return nil, c.errNotFound return nil, c.errNotFound
} else if err != c.errNotFound { } else if err != c.errNotFound {
@@ -183,8 +215,8 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e
} }
if err = query(v); err == c.errNotFound { if err = query(v); err == c.errNotFound {
if err = c.setCacheWithNotFound(key); err != nil { if err = c.setCacheWithNotFound(ctx, key); err != nil {
logx.Error(err) logger.Error(err)
} }
return nil, c.errNotFound return nil, c.errNotFound
@@ -194,7 +226,7 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e
} }
if err = cacheVal(v); err != nil { if err = cacheVal(v); err != nil {
logx.Error(err) logger.Error(err)
} }
} }
@@ -214,7 +246,7 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e
return jsonx.Unmarshal(val.([]byte), v) return jsonx.Unmarshal(val.([]byte), v)
} }
func (c cacheNode) processCache(key, data string, v interface{}) error { func (c cacheNode) processCache(ctx context.Context, key, data string, v interface{}) error {
err := jsonx.Unmarshal([]byte(data), v) err := jsonx.Unmarshal([]byte(data), v)
if err == nil { if err == nil {
return nil return nil
@@ -222,10 +254,11 @@ func (c cacheNode) processCache(key, data string, v interface{}) error {
report := fmt.Sprintf("unmarshal cache, node: %s, key: %s, value: %s, error: %v", report := fmt.Sprintf("unmarshal cache, node: %s, key: %s, value: %s, error: %v",
c.rds.Addr, key, data, err) c.rds.Addr, key, data, err)
logx.Error(report) logger := logx.WithContext(ctx)
logger.Error(report)
stat.Report(report) stat.Report(report)
if _, e := c.rds.Del(key); e != nil { if _, e := c.rds.DelCtx(ctx, key); e != nil {
logx.Errorf("delete invalid cache, node: %s, key: %s, value: %s, error: %v", logger.Errorf("delete invalid cache, node: %s, key: %s, value: %s, error: %v",
c.rds.Addr, key, data, e) c.rds.Addr, key, data, e)
} }
@@ -233,6 +266,6 @@ func (c cacheNode) processCache(key, data string, v interface{}) error {
return c.errNotFound return c.errNotFound
} }
func (c cacheNode) setCacheWithNotFound(key string) error { func (c cacheNode) setCacheWithNotFound(ctx context.Context, key string) error {
return c.rds.Setex(key, notFoundPlaceholder, int(c.aroundDuration(c.notFoundExpiry).Seconds())) return c.rds.SetexCtx(ctx, key, notFoundPlaceholder, int(c.aroundDuration(c.notFoundExpiry).Seconds()))
} }

File diff suppressed because it is too large Load Diff

View File

@@ -490,6 +490,29 @@ func TestRedis_SetExNx(t *testing.T) {
}) })
} }
func TestRedis_Getset(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()}
_, err := store.GetSet("hello", "world")
assert.NotNil(t, err)
runOnCluster(t, func(client Store) {
val, err := client.GetSet("hello", "world")
assert.Nil(t, err)
assert.Equal(t, "", val)
val, err = client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "world", val)
val, err = client.GetSet("hello", "newworld")
assert.Nil(t, err)
assert.Equal(t, "world", val)
val, err = client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "newworld", val)
_, err = client.Del("hello")
assert.Nil(t, err)
})
}
func TestRedis_SetGetDelHashField(t *testing.T) { func TestRedis_SetGetDelHashField(t *testing.T) {
store := clusterStore{dispatcher: hash.NewConsistentHash()} store := clusterStore{dispatcher: hash.NewConsistentHash()}
err := store.Hset("key", "field", "value") err := store.Hset("key", "field", "value")

View File

@@ -0,0 +1,91 @@
package mon
import (
"context"
"time"
"github.com/zeromicro/go-zero/core/executors"
"github.com/zeromicro/go-zero/core/logx"
"go.mongodb.org/mongo-driver/mongo"
)
const (
flushInterval = time.Second
maxBulkRows = 1000
)
type (
// ResultHandler is a handler that used to handle results.
ResultHandler func(*mongo.InsertManyResult, error)
// A BulkInserter is used to insert bulk of mongo records.
BulkInserter struct {
executor *executors.PeriodicalExecutor
inserter *dbInserter
}
)
// NewBulkInserter returns a BulkInserter.
func NewBulkInserter(coll *mongo.Collection, interval ...time.Duration) *BulkInserter {
inserter := &dbInserter{
collection: coll,
}
duration := flushInterval
if len(interval) > 0 {
duration = interval[0]
}
return &BulkInserter{
executor: executors.NewPeriodicalExecutor(duration, inserter),
inserter: inserter,
}
}
// Flush flushes the inserter, writes all pending records.
func (bi *BulkInserter) Flush() {
bi.executor.Flush()
}
// Insert inserts doc.
func (bi *BulkInserter) Insert(doc interface{}) {
bi.executor.Add(doc)
}
// SetResultHandler sets the result handler.
func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
bi.executor.Sync(func() {
bi.inserter.resultHandler = handler
})
}
type dbInserter struct {
collection *mongo.Collection
documents []interface{}
resultHandler ResultHandler
}
func (in *dbInserter) AddTask(doc interface{}) bool {
in.documents = append(in.documents, doc)
return len(in.documents) >= maxBulkRows
}
func (in *dbInserter) Execute(objs interface{}) {
docs := objs.([]interface{})
if len(docs) == 0 {
return
}
result, err := in.collection.InsertMany(context.Background(), docs)
if in.resultHandler != nil {
in.resultHandler(result, err)
} else if err != nil {
logx.Error(err)
}
}
func (in *dbInserter) RemoveAll() interface{} {
documents := in.documents
in.documents = nil
return documents
}

View File

@@ -0,0 +1,27 @@
package mon
import (
"testing"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
)
func TestBulkInserter(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
bulk := NewBulkInserter(mt.Coll)
bulk.SetResultHandler(func(result *mongo.InsertManyResult, err error) {
assert.Nil(t, err)
assert.Equal(t, 2, len(result.InsertedIDs))
})
bulk.Insert(bson.D{{Key: "foo", Value: "bar"}})
bulk.Insert(bson.D{{Key: "foo", Value: "baz"}})
bulk.Flush()
})
}

View File

@@ -0,0 +1,51 @@
package mon
import (
"context"
"io"
"time"
"github.com/zeromicro/go-zero/core/syncx"
"go.mongodb.org/mongo-driver/mongo"
mopt "go.mongodb.org/mongo-driver/mongo/options"
)
const defaultTimeout = time.Second
var clientManager = syncx.NewResourceManager()
// ClosableClient wraps *mongo.Client and provides a Close method.
type ClosableClient struct {
*mongo.Client
}
// Close disconnects the underlying *mongo.Client.
func (cs *ClosableClient) Close() error {
return cs.Client.Disconnect(context.Background())
}
// Inject injects a *mongo.Client into the client manager.
// Typically, this is used to inject a *mongo.Client for test purpose.
func Inject(key string, client *mongo.Client) {
clientManager.Inject(key, &ClosableClient{client})
}
func getClient(url string) (*mongo.Client, error) {
val, err := clientManager.GetResource(url, func() (io.Closer, error) {
cli, err := mongo.Connect(context.Background(), mopt.Client().ApplyURI(url))
if err != nil {
return nil, err
}
concurrentSess := &ClosableClient{
Client: cli,
}
return concurrentSess, nil
})
if err != nil {
return nil, err
}
return val.(*ClosableClient).Client, nil
}

View File

@@ -0,0 +1,20 @@
package mon
import (
"testing"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
)
func TestClientManger_getClient(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
Inject(mtest.ClusterURI(), mt.Client)
cli, err := getClient(mtest.ClusterURI())
assert.Nil(t, err)
assert.Equal(t, mt.Client, cli)
})
}

View File

@@ -0,0 +1,490 @@
package mon
import (
"context"
"encoding/json"
"time"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/timex"
"github.com/zeromicro/go-zero/core/trace"
"go.mongodb.org/mongo-driver/mongo"
mopt "go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.opentelemetry.io/otel"
tracesdk "go.opentelemetry.io/otel/trace"
)
const (
defaultSlowThreshold = time.Millisecond * 500
// spanName is the span name of the mongo calls.
spanName = "mongo"
)
// ErrNotFound is an alias of mongo.ErrNoDocuments
var ErrNotFound = mongo.ErrNoDocuments
type (
// Collection defines a MongoDB collection.
Collection interface {
// Aggregate executes an aggregation pipeline.
Aggregate(ctx context.Context, pipeline interface{}, opts ...*mopt.AggregateOptions) (
*mongo.Cursor, error)
// BulkWrite performs a bulk write operation.
BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...*mopt.BulkWriteOptions) (
*mongo.BulkWriteResult, error)
// Clone creates a copy of this collection with the same settings.
Clone(opts ...*mopt.CollectionOptions) (*mongo.Collection, error)
// CountDocuments returns the number of documents in the collection that match the filter.
CountDocuments(ctx context.Context, filter interface{}, opts ...*mopt.CountOptions) (int64, error)
// Database returns the database that this collection is a part of.
Database() *mongo.Database
// DeleteMany deletes documents from the collection that match the filter.
DeleteMany(ctx context.Context, filter interface{}, opts ...*mopt.DeleteOptions) (
*mongo.DeleteResult, error)
// DeleteOne deletes at most one document from the collection that matches the filter.
DeleteOne(ctx context.Context, filter interface{}, opts ...*mopt.DeleteOptions) (
*mongo.DeleteResult, error)
// Distinct returns a list of distinct values for the given key across the collection.
Distinct(ctx context.Context, fieldName string, filter interface{},
opts ...*mopt.DistinctOptions) ([]interface{}, error)
// Drop drops this collection from database.
Drop(ctx context.Context) error
// EstimatedDocumentCount returns an estimate of the count of documents in a collection
// using collection metadata.
EstimatedDocumentCount(ctx context.Context, opts ...*mopt.EstimatedDocumentCountOptions) (int64, error)
// Find finds the documents matching the provided filter.
Find(ctx context.Context, filter interface{}, opts ...*mopt.FindOptions) (*mongo.Cursor, error)
// FindOne returns up to one document that matches the provided filter.
FindOne(ctx context.Context, filter interface{}, opts ...*mopt.FindOneOptions) (
*mongo.SingleResult, error)
// FindOneAndDelete returns at most one document that matches the filter. If the filter
// matches multiple documents, only the first document is deleted.
FindOneAndDelete(ctx context.Context, filter interface{}, opts ...*mopt.FindOneAndDeleteOptions) (
*mongo.SingleResult, error)
// FindOneAndReplace returns at most one document that matches the filter. If the filter
// matches multiple documents, FindOneAndReplace returns the first document in the
// collection that matches the filter.
FindOneAndReplace(ctx context.Context, filter interface{}, replacement interface{},
opts ...*mopt.FindOneAndReplaceOptions) (*mongo.SingleResult, error)
// FindOneAndUpdate returns at most one document that matches the filter. If the filter
// matches multiple documents, FindOneAndUpdate returns the first document in the
// collection that matches the filter.
FindOneAndUpdate(ctx context.Context, filter interface{}, update interface{},
opts ...*mopt.FindOneAndUpdateOptions) (*mongo.SingleResult, error)
// Indexes returns the index view for this collection.
Indexes() mongo.IndexView
// InsertMany inserts the provided documents.
InsertMany(ctx context.Context, documents []interface{}, opts ...*mopt.InsertManyOptions) (
*mongo.InsertManyResult, error)
// InsertOne inserts the provided document.
InsertOne(ctx context.Context, document interface{}, opts ...*mopt.InsertOneOptions) (
*mongo.InsertOneResult, error)
// ReplaceOne replaces at most one document that matches the filter.
ReplaceOne(ctx context.Context, filter interface{}, replacement interface{},
opts ...*mopt.ReplaceOptions) (*mongo.UpdateResult, error)
// UpdateByID updates a single document matching the provided filter.
UpdateByID(ctx context.Context, id interface{}, update interface{},
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error)
// UpdateMany updates the provided documents.
UpdateMany(ctx context.Context, filter interface{}, update interface{},
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error)
// UpdateOne updates a single document matching the provided filter.
UpdateOne(ctx context.Context, filter interface{}, update interface{},
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error)
// Watch returns a change stream cursor used to receive notifications of changes to the collection.
Watch(ctx context.Context, pipeline interface{}, opts ...*mopt.ChangeStreamOptions) (
*mongo.ChangeStream, error)
}
decoratedCollection struct {
*mongo.Collection
name string
brk breaker.Breaker
}
keepablePromise struct {
promise breaker.Promise
log func(error)
}
)
func newCollection(collection *mongo.Collection, brk breaker.Breaker) Collection {
return &decoratedCollection{
Collection: collection,
name: collection.Name(),
brk: brk,
}
}
func (c *decoratedCollection) Aggregate(ctx context.Context, pipeline interface{},
opts ...*mopt.AggregateOptions) (cur *mongo.Cursor, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
starTime := timex.Now()
defer func() {
c.logDurationSimple("Aggregate", starTime, err)
}()
cur, err = c.Collection.Aggregate(ctx, pipeline, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) BulkWrite(ctx context.Context, models []mongo.WriteModel,
opts ...*mopt.BulkWriteOptions) (res *mongo.BulkWriteResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDurationSimple("BulkWrite", startTime, err)
}()
res, err = c.Collection.BulkWrite(ctx, models, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) CountDocuments(ctx context.Context, filter interface{},
opts ...*mopt.CountOptions) (count int64, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDurationSimple("CountDocuments", startTime, err)
}()
count, err = c.Collection.CountDocuments(ctx, filter, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) DeleteMany(ctx context.Context, filter interface{},
opts ...*mopt.DeleteOptions) (res *mongo.DeleteResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDurationSimple("DeleteMany", startTime, err)
}()
res, err = c.Collection.DeleteMany(ctx, filter, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) DeleteOne(ctx context.Context, filter interface{},
opts ...*mopt.DeleteOptions) (res *mongo.DeleteResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDuration("DeleteOne", startTime, err, filter)
}()
res, err = c.Collection.DeleteOne(ctx, filter, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) Distinct(ctx context.Context, fieldName string, filter interface{},
opts ...*mopt.DistinctOptions) (val []interface{}, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDurationSimple("Distinct", startTime, err)
}()
val, err = c.Collection.Distinct(ctx, fieldName, filter, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) EstimatedDocumentCount(ctx context.Context,
opts ...*mopt.EstimatedDocumentCountOptions) (val int64, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDurationSimple("EstimatedDocumentCount", startTime, err)
}()
val, err = c.Collection.EstimatedDocumentCount(ctx, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) Find(ctx context.Context, filter interface{},
opts ...*mopt.FindOptions) (cur *mongo.Cursor, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDuration("Find", startTime, err, filter)
}()
cur, err = c.Collection.Find(ctx, filter, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) FindOne(ctx context.Context, filter interface{},
opts ...*mopt.FindOneOptions) (res *mongo.SingleResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDuration("FindOne", startTime, err, filter)
}()
res = c.Collection.FindOne(ctx, filter, opts...)
err = res.Err()
return err
}, acceptable)
return
}
func (c *decoratedCollection) FindOneAndDelete(ctx context.Context, filter interface{},
opts ...*mopt.FindOneAndDeleteOptions) (res *mongo.SingleResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDuration("FindOneAndDelete", startTime, err, filter)
}()
res = c.Collection.FindOneAndDelete(ctx, filter, opts...)
err = res.Err()
return err
}, acceptable)
return
}
func (c *decoratedCollection) FindOneAndReplace(ctx context.Context, filter interface{},
replacement interface{}, opts ...*mopt.FindOneAndReplaceOptions) (
res *mongo.SingleResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDuration("FindOneAndReplace", startTime, err, filter, replacement)
}()
res = c.Collection.FindOneAndReplace(ctx, filter, replacement, opts...)
err = res.Err()
return err
}, acceptable)
return
}
func (c *decoratedCollection) FindOneAndUpdate(ctx context.Context, filter interface{}, update interface{},
opts ...*mopt.FindOneAndUpdateOptions) (res *mongo.SingleResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDuration("FindOneAndUpdate", startTime, err, filter, update)
}()
res = c.Collection.FindOneAndUpdate(ctx, filter, update, opts...)
err = res.Err()
return err
}, acceptable)
return
}
func (c *decoratedCollection) InsertMany(ctx context.Context, documents []interface{},
opts ...*mopt.InsertManyOptions) (res *mongo.InsertManyResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDurationSimple("InsertMany", startTime, err)
}()
res, err = c.Collection.InsertMany(ctx, documents, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) InsertOne(ctx context.Context, document interface{},
opts ...*mopt.InsertOneOptions) (res *mongo.InsertOneResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDuration("InsertOne", startTime, err, document)
}()
res, err = c.Collection.InsertOne(ctx, document, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) ReplaceOne(ctx context.Context, filter interface{}, replacement interface{},
opts ...*mopt.ReplaceOptions) (res *mongo.UpdateResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDuration("ReplaceOne", startTime, err, filter, replacement)
}()
res, err = c.Collection.ReplaceOne(ctx, filter, replacement, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) UpdateByID(ctx context.Context, id interface{}, update interface{},
opts ...*mopt.UpdateOptions) (res *mongo.UpdateResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDuration("UpdateByID", startTime, err, id, update)
}()
res, err = c.Collection.UpdateByID(ctx, id, update, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) UpdateMany(ctx context.Context, filter interface{}, update interface{},
opts ...*mopt.UpdateOptions) (res *mongo.UpdateResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDurationSimple("UpdateMany", startTime, err)
}()
res, err = c.Collection.UpdateMany(ctx, filter, update, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) UpdateOne(ctx context.Context, filter interface{}, update interface{},
opts ...*mopt.UpdateOptions) (res *mongo.UpdateResult, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
c.logDuration("UpdateOne", startTime, err, filter, update)
}()
res, err = c.Collection.UpdateOne(ctx, filter, update, opts...)
return err
}, acceptable)
return
}
func (c *decoratedCollection) logDuration(method string, startTime time.Duration, err error,
docs ...interface{}) {
duration := timex.Since(startTime)
content, e := json.Marshal(docs)
if e != nil {
logx.Error(err)
} else if err != nil {
if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - fail(%s) - %s",
c.name, method, err.Error(), string(content))
} else {
logx.WithDuration(duration).Infof("mongo(%s) - %s - fail(%s) - %s",
c.name, method, err.Error(), string(content))
}
} else {
if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - ok - %s",
c.name, method, string(content))
} else {
logx.WithDuration(duration).Infof("mongo(%s) - %s - ok - %s", c.name, method, string(content))
}
}
}
func (c *decoratedCollection) logDurationSimple(method string, startTime time.Duration, err error) {
logDuration(c.name, method, startTime, err)
}
func (p keepablePromise) accept(err error) error {
p.promise.Accept()
p.log(err)
return err
}
func (p keepablePromise) keep(err error) error {
if acceptable(err) {
p.promise.Accept()
} else {
p.promise.Reject(err.Error())
}
p.log(err)
return err
}
func acceptable(err error) bool {
return err == nil || err == mongo.ErrNoDocuments || err == mongo.ErrNilValue ||
err == mongo.ErrNilDocument || err == mongo.ErrNilCursor || err == mongo.ErrEmptySlice ||
// session errors
err == session.ErrSessionEnded || err == session.ErrNoTransactStarted ||
err == session.ErrTransactInProgress || err == session.ErrAbortAfterCommit ||
err == session.ErrAbortTwice || err == session.ErrCommitAfterAbort ||
err == session.ErrUnackWCUnsupported || err == session.ErrSnapshotTransaction
}
func startSpan(ctx context.Context) (context.Context, tracesdk.Span) {
tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
return tracer.Start(ctx, spanName)
}

View File

@@ -0,0 +1,608 @@
package mon
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
mopt "go.mongodb.org/mongo-driver/mongo/options"
)
var errDummy = errors.New("dummy")
func init() {
logx.Disable()
}
func TestKeepPromise_accept(t *testing.T) {
p := new(mockPromise)
kp := keepablePromise{
promise: p,
log: func(error) {},
}
assert.Nil(t, kp.accept(nil))
assert.Equal(t, ErrNotFound, kp.accept(ErrNotFound))
}
func TestKeepPromise_keep(t *testing.T) {
tests := []struct {
err error
accepted bool
reason string
}{
{
err: nil,
accepted: true,
reason: "",
},
{
err: ErrNotFound,
accepted: true,
reason: "",
},
{
err: errors.New("any"),
accepted: false,
reason: "any",
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
p := new(mockPromise)
kp := keepablePromise{
promise: p,
log: func(error) {},
}
assert.Equal(t, test.err, kp.keep(test.err))
assert.Equal(t, test.accepted, p.accepted)
assert.Equal(t, test.reason, p.reason)
})
}
}
func TestNewCollection(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
coll := mt.Coll
assert.NotNil(t, coll)
col := newCollection(coll, breaker.GetBreaker("localhost"))
assert.Equal(t, t.Name()+"/test", col.(*decoratedCollection).name)
})
}
func TestCollection_Aggregate(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
coll := mt.Coll
assert.NotNil(t, coll)
col := newCollection(coll, breaker.GetBreaker("localhost"))
ns := mt.Coll.Database().Name() + "." + mt.Coll.Name()
aggRes := mtest.CreateCursorResponse(1, ns, mtest.FirstBatch)
mt.AddMockResponses(aggRes)
assert.Equal(t, t.Name()+"/test", col.(*decoratedCollection).name)
cursor, err := col.Aggregate(context.Background(), mongo.Pipeline{}, mopt.Aggregate())
assert.Nil(t, err)
cursor.Close(context.Background())
})
}
func TestCollection_BulkWrite(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
res, err := c.BulkWrite(context.Background(), []mongo.WriteModel{
mongo.NewInsertOneModel().SetDocument(bson.D{{Key: "foo", Value: 1}})},
)
assert.Nil(t, err)
assert.NotNil(t, res)
c.brk = new(dropBreaker)
_, err = c.BulkWrite(context.Background(), []mongo.WriteModel{
mongo.NewInsertOneModel().SetDocument(bson.D{{Key: "foo", Value: 1}})},
)
assert.Equal(t, errDummy, err)
})
}
func TestCollection_CountDocuments(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "n", Value: 1},
}))
res, err := c.CountDocuments(context.Background(), bson.D{})
assert.Nil(t, err)
assert.Equal(t, int64(1), res)
c.brk = new(dropBreaker)
_, err = c.CountDocuments(context.Background(), bson.D{{Key: "foo", Value: 1}})
assert.Equal(t, errDummy, err)
})
}
func TestDecoratedCollection_DeleteMany(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
res, err := c.DeleteMany(context.Background(), bson.D{})
assert.Nil(t, err)
assert.Equal(t, int64(1), res.DeletedCount)
c.brk = new(dropBreaker)
_, err = c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: 1}})
assert.Equal(t, errDummy, err)
})
}
func TestCollection_Distinct(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "values", Value: []int{1}}})
resp, err := c.Distinct(context.Background(), "foo", bson.D{})
assert.Nil(t, err)
assert.Equal(t, 1, len(resp))
c.brk = new(dropBreaker)
_, err = c.Distinct(context.Background(), "foo", bson.D{{Key: "foo", Value: 1}})
assert.Equal(t, errDummy, err)
})
}
func TestCollection_EstimatedDocumentCount(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "n", Value: 1}})
res, err := c.EstimatedDocumentCount(context.Background())
assert.Nil(t, err)
assert.Equal(t, int64(1), res)
c.brk = new(dropBreaker)
_, err = c.EstimatedDocumentCount(context.Background())
assert.Equal(t, errDummy, err)
})
}
func TestCollectionFind(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
find := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "name", Value: "John"},
})
getMore := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.NextBatch,
bson.D{
{Key: "name", Value: "Mary"},
})
killCursors := mtest.CreateCursorResponse(
0,
"DBName.CollectionName",
mtest.NextBatch)
mt.AddMockResponses(find, getMore, killCursors)
filter := bson.D{{Key: "x", Value: 1}}
cursor, err := c.Find(context.Background(), filter, mopt.Find())
assert.Nil(t, err)
defer cursor.Close(context.Background())
var val []struct {
ID primitive.ObjectID `bson:"_id"`
Name string `bson:"name"`
}
assert.Nil(t, cursor.All(context.Background(), &val))
assert.Equal(t, 2, len(val))
assert.Equal(t, "John", val[0].Name)
assert.Equal(t, "Mary", val[1].Name)
c.brk = new(dropBreaker)
_, err = c.Find(context.Background(), filter, mopt.Find())
assert.Equal(t, errDummy, err)
})
}
func TestCollectionFindOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
find := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "name", Value: "John"},
})
getMore := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.NextBatch,
bson.D{
{Key: "name", Value: "Mary"},
})
killCursors := mtest.CreateCursorResponse(
0,
"DBName.CollectionName",
mtest.NextBatch)
mt.AddMockResponses(find, getMore, killCursors)
filter := bson.D{{Key: "x", Value: 1}}
resp, err := c.FindOne(context.Background(), filter)
assert.Nil(t, err)
var val struct {
ID primitive.ObjectID `bson:"_id"`
Name string `bson:"name"`
}
assert.Nil(t, resp.Decode(&val))
assert.Equal(t, "John", val.Name)
c.brk = new(dropBreaker)
_, err = c.FindOne(context.Background(), filter)
assert.Equal(t, errDummy, err)
})
}
func TestCollection_FindOneAndDelete(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
filter := bson.D{}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{}...))
_, err := c.FindOneAndDelete(context.Background(), filter, mopt.FindOneAndDelete())
assert.Equal(t, mongo.ErrNoDocuments, err)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "name", Value: "John"}}},
}...))
resp, err := c.FindOneAndDelete(context.Background(), filter, mopt.FindOneAndDelete())
assert.Nil(t, err)
var val struct {
Name string `bson:"name"`
}
assert.Nil(t, resp.Decode(&val))
assert.Equal(t, "John", val.Name)
c.brk = new(dropBreaker)
_, err = c.FindOneAndDelete(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errDummy, err)
})
}
func TestCollection_FindOneAndReplace(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{}...))
filter := bson.D{{Key: "x", Value: 1}}
replacement := bson.D{{Key: "x", Value: 2}}
opts := mopt.FindOneAndReplace().SetUpsert(true)
_, err := c.FindOneAndReplace(context.Background(), filter, replacement, opts)
assert.Equal(t, mongo.ErrNoDocuments, err)
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "value", Value: bson.D{
{Key: "name", Value: "John"},
}}})
resp, err := c.FindOneAndReplace(context.Background(), filter, replacement, opts)
assert.Nil(t, err)
var val struct {
Name string `bson:"name"`
}
assert.Nil(t, resp.Decode(&val))
assert.Equal(t, "John", val.Name)
c.brk = new(dropBreaker)
_, err = c.FindOneAndReplace(context.Background(), filter, replacement, opts)
assert.Equal(t, errDummy, err)
})
}
func TestCollection_FindOneAndUpdate(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}})
filter := bson.D{{Key: "x", Value: 1}}
update := bson.D{{Key: "$x", Value: 2}}
opts := mopt.FindOneAndUpdate().SetUpsert(true)
_, err := c.FindOneAndUpdate(context.Background(), filter, update, opts)
assert.Equal(t, mongo.ErrNoDocuments, err)
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "value", Value: bson.D{
{Key: "name", Value: "John"},
}}})
resp, err := c.FindOneAndUpdate(context.Background(), filter, update, opts)
assert.Nil(t, err)
var val struct {
Name string `bson:"name"`
}
assert.Nil(t, resp.Decode(&val))
assert.Equal(t, "John", val.Name)
c.brk = new(dropBreaker)
_, err = c.FindOneAndUpdate(context.Background(), filter, update, opts)
assert.Equal(t, errDummy, err)
})
}
func TestCollection_InsertOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
res, err := c.InsertOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.NotNil(t, res)
c.brk = new(dropBreaker)
_, err = c.InsertOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errDummy, err)
})
}
func TestCollection_InsertMany(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
res, err := c.InsertMany(context.Background(), []interface{}{
bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "foo", Value: "baz"}},
})
assert.Nil(t, err)
assert.NotNil(t, res)
assert.Equal(t, 2, len(res.InsertedIDs))
c.brk = new(dropBreaker)
_, err = c.InsertMany(context.Background(), []interface{}{bson.D{{Key: "foo", Value: "bar"}}})
assert.Equal(t, errDummy, err)
})
}
func TestCollection_Remove(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
res, err := c.DeleteOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.Equal(t, int64(1), res.DeletedCount)
c.brk = new(dropBreaker)
_, err = c.DeleteOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errDummy, err)
})
}
func TestCollectionRemoveAll(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
res, err := c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.Equal(t, int64(1), res.DeletedCount)
c.brk = new(dropBreaker)
_, err = c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errDummy, err)
})
}
func TestCollection_ReplaceOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
res, err := c.ReplaceOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "foo", Value: "baz"}},
)
assert.Nil(t, err)
assert.Equal(t, int64(1), res.MatchedCount)
c.brk = new(dropBreaker)
_, err = c.ReplaceOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "foo", Value: "baz"}})
assert.Equal(t, errDummy, err)
})
}
func TestCollection_UpdateOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
resp, err := c.UpdateOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Nil(t, err)
assert.Equal(t, int64(1), resp.MatchedCount)
c.brk = new(dropBreaker)
_, err = c.UpdateOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Equal(t, errDummy, err)
})
}
func TestCollection_UpdateByID(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
resp, err := c.UpdateByID(context.Background(), primitive.NewObjectID(),
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Nil(t, err)
assert.Equal(t, int64(1), resp.MatchedCount)
c.brk = new(dropBreaker)
_, err = c.UpdateByID(context.Background(), primitive.NewObjectID(),
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Equal(t, errDummy, err)
})
}
func TestCollection_UpdateMany(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
c := decoratedCollection{
Collection: mt.Coll,
brk: breaker.NewBreaker(),
}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
resp, err := c.UpdateMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Nil(t, err)
assert.Equal(t, int64(1), resp.MatchedCount)
c.brk = new(dropBreaker)
_, err = c.UpdateMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
assert.Equal(t, errDummy, err)
})
}
type mockPromise struct {
accepted bool
reason string
}
func (p *mockPromise) Accept() {
p.accepted = true
}
func (p *mockPromise) Reject(reason string) {
p.reason = reason
}
type dropBreaker struct{}
func (d *dropBreaker) Name() string {
return "dummy"
}
func (d *dropBreaker) Allow() (breaker.Promise, error) {
return nil, errDummy
}
func (d *dropBreaker) Do(req func() error) error {
return nil
}
func (d *dropBreaker) DoWithAcceptable(req func() error, acceptable breaker.Acceptable) error {
return errDummy
}
func (d *dropBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
return nil
}
func (d *dropBreaker) DoWithFallbackAcceptable(_ func() error, _ func(err error) error,
_ breaker.Acceptable) error {
return nil
}

214
core/stores/mon/model.go Normal file
View File

@@ -0,0 +1,214 @@
package mon
import (
"context"
"log"
"strings"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/timex"
"go.mongodb.org/mongo-driver/mongo"
mopt "go.mongodb.org/mongo-driver/mongo/options"
)
type (
// Model is a mongodb store model that represents a collection.
Model struct {
Collection
name string
cli *mongo.Client
brk breaker.Breaker
opts []Option
}
wrappedSession struct {
mongo.Session
brk breaker.Breaker
}
)
// MustNewModel returns a Model, exits on errors.
func MustNewModel(uri, db, collection string, opts ...Option) *Model {
model, err := NewModel(uri, db, collection, opts...)
if err != nil {
log.Fatal(err)
}
return model
}
// NewModel returns a Model.
func NewModel(uri, db, collection string, opts ...Option) (*Model, error) {
cli, err := getClient(uri)
if err != nil {
return nil, err
}
name := strings.Join([]string{uri, collection}, "/")
brk := breaker.GetBreaker(uri)
coll := newCollection(cli.Database(db).Collection(collection), brk)
return newModel(name, cli, coll, brk, opts...), nil
}
func newModel(name string, cli *mongo.Client, coll Collection, brk breaker.Breaker,
opts ...Option) *Model {
return &Model{
name: name,
Collection: coll,
cli: cli,
brk: brk,
opts: opts,
}
}
// StartSession starts a new session.
func (m *Model) StartSession(opts ...*mopt.SessionOptions) (sess mongo.Session, err error) {
err = m.brk.DoWithAcceptable(func() error {
starTime := timex.Now()
defer func() {
logDuration(m.name, "StartSession", starTime, err)
}()
session, sessionErr := m.cli.StartSession(opts...)
if sessionErr != nil {
return sessionErr
}
sess = &wrappedSession{
Session: session,
brk: m.brk,
}
return nil
}, acceptable)
return
}
// Aggregate executes an aggregation pipeline.
func (m *Model) Aggregate(ctx context.Context, v, pipeline interface{}, opts ...*mopt.AggregateOptions) error {
cur, err := m.Collection.Aggregate(ctx, pipeline, opts...)
if err != nil {
return err
}
defer cur.Close(ctx)
return cur.All(ctx, v)
}
// DeleteMany deletes documents that match the filter.
func (m *Model) DeleteMany(ctx context.Context, filter interface{}, opts ...*mopt.DeleteOptions) (int64, error) {
res, err := m.Collection.DeleteMany(ctx, filter, opts...)
if err != nil {
return 0, err
}
return res.DeletedCount, nil
}
// DeleteOne deletes the first document that matches the filter.
func (m *Model) DeleteOne(ctx context.Context, filter interface{}, opts ...*mopt.DeleteOptions) (int64, error) {
res, err := m.Collection.DeleteOne(ctx, filter, opts...)
if err != nil {
return 0, err
}
return res.DeletedCount, nil
}
// Find finds documents that match the filter.
func (m *Model) Find(ctx context.Context, v, filter interface{}, opts ...*mopt.FindOptions) error {
cur, err := m.Collection.Find(ctx, filter, opts...)
if err != nil {
return err
}
defer cur.Close(ctx)
return cur.All(ctx, v)
}
// FindOne finds the first document that matches the filter.
func (m *Model) FindOne(ctx context.Context, v, filter interface{}, opts ...*mopt.FindOneOptions) error {
res, err := m.Collection.FindOne(ctx, filter, opts...)
if err != nil {
return err
}
return res.Decode(v)
}
// FindOneAndDelete finds a single document and deletes it.
func (m *Model) FindOneAndDelete(ctx context.Context, v, filter interface{},
opts ...*mopt.FindOneAndDeleteOptions) error {
res, err := m.Collection.FindOneAndDelete(ctx, filter, opts...)
if err != nil {
return err
}
return res.Decode(v)
}
// FindOneAndReplace finds a single document and replaces it.
func (m *Model) FindOneAndReplace(ctx context.Context, v, filter interface{}, replacement interface{},
opts ...*mopt.FindOneAndReplaceOptions) error {
res, err := m.Collection.FindOneAndReplace(ctx, filter, replacement, opts...)
if err != nil {
return err
}
return res.Decode(v)
}
// FindOneAndUpdate finds a single document and updates it.
func (m *Model) FindOneAndUpdate(ctx context.Context, v, filter interface{}, update interface{},
opts ...*mopt.FindOneAndUpdateOptions) error {
res, err := m.Collection.FindOneAndUpdate(ctx, filter, update, opts...)
if err != nil {
return err
}
return res.Decode(v)
}
func (w *wrappedSession) AbortTransaction(ctx context.Context) error {
ctx, span := startSpan(ctx)
defer span.End()
return w.brk.DoWithAcceptable(func() error {
return w.Session.AbortTransaction(ctx)
}, acceptable)
}
func (w *wrappedSession) CommitTransaction(ctx context.Context) error {
ctx, span := startSpan(ctx)
defer span.End()
return w.brk.DoWithAcceptable(func() error {
return w.Session.CommitTransaction(ctx)
}, acceptable)
}
func (w *wrappedSession) WithTransaction(
ctx context.Context,
fn func(sessCtx mongo.SessionContext) (interface{}, error),
opts ...*mopt.TransactionOptions,
) (res interface{}, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = w.brk.DoWithAcceptable(func() error {
res, err = w.Session.WithTransaction(ctx, fn, opts...)
return err
}, acceptable)
return
}
func (w *wrappedSession) EndSession(ctx context.Context) {
ctx, span := startSpan(ctx)
defer span.End()
_ = w.brk.DoWithAcceptable(func() error {
w.Session.EndSession(ctx)
return nil
}, acceptable)
}

View File

@@ -0,0 +1,243 @@
package mon
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
)
func TestModel_StartSession(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
sess, err := m.StartSession()
assert.Nil(t, err)
defer sess.EndSession(context.Background())
_, err = sess.WithTransaction(context.Background(), func(sessCtx mongo.SessionContext) (interface{}, error) {
_ = sessCtx.StartTransaction()
sessCtx.Client().Database("1")
sessCtx.EndSession(context.Background())
return nil, nil
})
assert.Nil(t, err)
assert.NoError(t, sess.CommitTransaction(context.Background()))
assert.Error(t, sess.AbortTransaction(context.Background()))
})
}
func TestModel_Aggregate(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
find := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "name", Value: "John"},
})
getMore := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.NextBatch,
bson.D{
{Key: "name", Value: "Mary"},
})
killCursors := mtest.CreateCursorResponse(
0,
"DBName.CollectionName",
mtest.NextBatch)
mt.AddMockResponses(find, getMore, killCursors)
var result []interface{}
err := m.Aggregate(context.Background(), &result, mongo.Pipeline{})
assert.Nil(t, err)
assert.Equal(t, 2, len(result))
assert.Equal(t, "John", result[0].(bson.D).Map()["name"])
assert.Equal(t, "Mary", result[1].(bson.D).Map()["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.Aggregate(context.Background(), &result, mongo.Pipeline{}))
})
}
func TestModel_DeleteMany(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
val, err := m.DeleteMany(context.Background(), bson.D{})
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
triggerBreaker(m)
_, err = m.DeleteMany(context.Background(), bson.D{})
assert.Equal(t, errDummy, err)
})
}
func TestModel_DeleteOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
val, err := m.DeleteOne(context.Background(), bson.D{})
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
triggerBreaker(m)
_, err = m.DeleteOne(context.Background(), bson.D{})
assert.Equal(t, errDummy, err)
})
}
func TestModel_Find(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
find := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "name", Value: "John"},
})
getMore := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.NextBatch,
bson.D{
{Key: "name", Value: "Mary"},
})
killCursors := mtest.CreateCursorResponse(
0,
"DBName.CollectionName",
mtest.NextBatch)
mt.AddMockResponses(find, getMore, killCursors)
var result []interface{}
err := m.Find(context.Background(), &result, bson.D{})
assert.Nil(t, err)
assert.Equal(t, 2, len(result))
assert.Equal(t, "John", result[0].(bson.D).Map()["name"])
assert.Equal(t, "Mary", result[1].(bson.D).Map()["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.Find(context.Background(), &result, bson.D{}))
})
}
func TestModel_FindOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
find := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "name", Value: "John"},
})
killCursors := mtest.CreateCursorResponse(
0,
"DBName.CollectionName",
mtest.NextBatch)
mt.AddMockResponses(find, killCursors)
var result bson.D
err := m.FindOne(context.Background(), &result, bson.D{})
assert.Nil(t, err)
assert.Equal(t, "John", result.Map()["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.FindOne(context.Background(), &result, bson.D{}))
})
}
func TestModel_FindOneAndDelete(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "name", Value: "John"}}},
}...))
var result bson.D
err := m.FindOneAndDelete(context.Background(), &result, bson.D{})
assert.Nil(t, err)
assert.Equal(t, "John", result.Map()["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.FindOneAndDelete(context.Background(), &result, bson.D{}))
})
}
func TestModel_FindOneAndReplace(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "name", Value: "John"}}},
}...))
var result bson.D
err := m.FindOneAndReplace(context.Background(), &result, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
})
assert.Nil(t, err)
assert.Equal(t, "John", result.Map()["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.FindOneAndReplace(context.Background(), &result, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
})
}
func TestModel_FindOneAndUpdate(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
m := createModel(mt)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "name", Value: "John"}}},
}...))
var result bson.D
err := m.FindOneAndUpdate(context.Background(), &result, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
})
assert.Nil(t, err)
assert.Equal(t, "John", result.Map()["name"])
triggerBreaker(m)
assert.Equal(t, errDummy, m.FindOneAndUpdate(context.Background(), &result, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
})
}
func createModel(mt *mtest.T) *Model {
Inject(mt.Name(), mt.Client)
return MustNewModel(mt.Name(), mt.DB.Name(), mt.Coll.Name())
}
func triggerBreaker(m *Model) {
m.Collection.(*decoratedCollection).brk = new(dropBreaker)
}

View File

@@ -0,0 +1,29 @@
package mon
import (
"time"
"github.com/zeromicro/go-zero/core/syncx"
)
var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
type (
options struct {
timeout time.Duration
}
// Option defines the method to customize a mongo model.
Option func(opts *options)
)
// SetSlowThreshold sets the slow threshold.
func SetSlowThreshold(threshold time.Duration) {
slowThreshold.Set(threshold)
}
func defaultOptions() *options {
return &options{
timeout: defaultTimeout,
}
}

View File

@@ -0,0 +1,18 @@
package mon
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestSetSlowThreshold(t *testing.T) {
assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
SetSlowThreshold(time.Second)
assert.Equal(t, time.Second, slowThreshold.Load())
}
func TestDefaultOptions(t *testing.T) {
assert.Equal(t, defaultTimeout, defaultOptions().timeout)
}

25
core/stores/mon/util.go Normal file
View File

@@ -0,0 +1,25 @@
package mon
import (
"strings"
"time"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/timex"
)
const mongoAddrSep = ","
// FormatAddr formats mongo hosts to a string.
func FormatAddr(hosts []string) string {
return strings.Join(hosts, mongoAddrSep)
}
func logDuration(name, method string, startTime time.Duration, err error) {
duration := timex.Since(startTime)
if err != nil {
logx.WithDuration(duration).Infof("mongo(%s) - %s - fail(%s)", name, method, err.Error())
} else {
logx.WithDuration(duration).Infof("mongo(%s) - %s - ok", name, method)
}
}

View File

@@ -0,0 +1,35 @@
package mon
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFormatAddrs(t *testing.T) {
tests := []struct {
addrs []string
expect string
}{
{
addrs: []string{"a", "b"},
expect: "a,b",
},
{
addrs: []string{"a", "b", "c"},
expect: "a,b,c",
},
{
addrs: []string{},
expect: "",
},
{
addrs: nil,
expect: "",
},
}
for _, test := range tests {
assert.Equal(t, test.expect, FormatAddr(test.addrs))
}
}

View File

@@ -0,0 +1,281 @@
package monc
import (
"context"
"log"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mon"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/syncx"
"go.mongodb.org/mongo-driver/mongo"
mopt "go.mongodb.org/mongo-driver/mongo/options"
)
var (
// ErrNotFound is an alias of mongo.ErrNoDocuments.
ErrNotFound = mongo.ErrNoDocuments
// can't use one SingleFlight per conn, because multiple conns may share the same cache key.
singleFlight = syncx.NewSingleFlight()
stats = cache.NewStat("monc")
)
// A Model is a mongo model that built with cache capability.
type Model struct {
*mon.Model
cache cache.Cache
}
// MustNewModel returns a Model with a cache cluster, exists on errors.
func MustNewModel(uri, db, collection string, c cache.CacheConf, opts ...cache.Option) *Model {
model, err := NewModel(uri, db, collection, c, opts...)
if err != nil {
log.Fatal(err)
}
return model
}
// MustNewNodeModel returns a Model with a cache node, exists on errors.
func MustNewNodeModel(uri, db, collection string, rds *redis.Redis, opts ...cache.Option) *Model {
model, err := NewNodeModel(uri, db, collection, rds, opts...)
if err != nil {
log.Fatal(err)
}
return model
}
// NewModel returns a Model with a cache cluster.
func NewModel(uri, db, collection string, conf cache.CacheConf, opts ...cache.Option) (*Model, error) {
c := cache.New(conf, singleFlight, stats, mongo.ErrNoDocuments, opts...)
return NewModelWithCache(uri, db, collection, c)
}
// NewModelWithCache returns a Model with a custom cache.
func NewModelWithCache(uri, db, collection string, c cache.Cache) (*Model, error) {
return newModel(uri, db, collection, c)
}
// NewNodeModel returns a Model with a cache node.
func NewNodeModel(uri, db, collection string, rds *redis.Redis, opts ...cache.Option) (*Model, error) {
c := cache.NewNode(rds, singleFlight, stats, mongo.ErrNoDocuments, opts...)
return NewModelWithCache(uri, db, collection, c)
}
// newModel returns a Model with the given cache.
func newModel(uri, db, collection string, c cache.Cache) (*Model, error) {
model, err := mon.NewModel(uri, db, collection)
if err != nil {
return nil, err
}
return &Model{
Model: model,
cache: c,
}, nil
}
// DelCache deletes the cache with given keys.
func (mm *Model) DelCache(ctx context.Context, keys ...string) error {
return mm.cache.DelCtx(ctx, keys...)
}
// DeleteOne deletes the document with given filter, and remove it from cache.
func (mm *Model) DeleteOne(ctx context.Context, key string, filter interface{},
opts ...*mopt.DeleteOptions) (int64, error) {
val, err := mm.Model.DeleteOne(ctx, filter, opts...)
if err != nil {
return 0, err
}
if err := mm.DelCache(ctx, key); err != nil {
return 0, err
}
return val, nil
}
// DeleteOneNoCache deletes the document with given filter.
func (mm *Model) DeleteOneNoCache(ctx context.Context, filter interface{},
opts ...*mopt.DeleteOptions) (int64, error) {
return mm.Model.DeleteOne(ctx, filter, opts...)
}
// FindOne unmarshals a record into v with given key and query.
func (mm *Model) FindOne(ctx context.Context, key string, v, filter interface{},
opts ...*mopt.FindOneOptions) error {
return mm.cache.TakeCtx(ctx, v, key, func(v interface{}) error {
return mm.Model.FindOne(ctx, v, filter, opts...)
})
}
// FindOneNoCache unmarshals a record into v with query, without cache.
func (mm *Model) FindOneNoCache(ctx context.Context, v, filter interface{},
opts ...*mopt.FindOneOptions) error {
return mm.Model.FindOne(ctx, v, filter, opts...)
}
// FindOneAndDelete deletes the document with given filter, and unmarshals it into v.
func (mm *Model) FindOneAndDelete(ctx context.Context, key string, v, filter interface{},
opts ...*mopt.FindOneAndDeleteOptions) error {
if err := mm.Model.FindOneAndDelete(ctx, v, filter, opts...); err != nil {
return err
}
return mm.DelCache(ctx, key)
}
// FindOneAndDeleteNoCache deletes the document with given filter, and unmarshals it into v.
func (mm *Model) FindOneAndDeleteNoCache(ctx context.Context, v, filter interface{},
opts ...*mopt.FindOneAndDeleteOptions) error {
return mm.Model.FindOneAndDelete(ctx, v, filter, opts...)
}
// FindOneAndReplace replaces the document with given filter with replacement, and unmarshals it into v.
func (mm *Model) FindOneAndReplace(ctx context.Context, key string, v, filter interface{},
replacement interface{}, opts ...*mopt.FindOneAndReplaceOptions) error {
if err := mm.Model.FindOneAndReplace(ctx, v, filter, replacement, opts...); err != nil {
return err
}
return mm.DelCache(ctx, key)
}
// FindOneAndReplaceNoCache replaces the document with given filter with replacement, and unmarshals it into v.
func (mm *Model) FindOneAndReplaceNoCache(ctx context.Context, v, filter interface{},
replacement interface{}, opts ...*mopt.FindOneAndReplaceOptions) error {
return mm.Model.FindOneAndReplace(ctx, v, filter, replacement, opts...)
}
// FindOneAndUpdate updates the document with given filter with update, and unmarshals it into v.
func (mm *Model) FindOneAndUpdate(ctx context.Context, key string, v, filter interface{},
update interface{}, opts ...*mopt.FindOneAndUpdateOptions) error {
if err := mm.Model.FindOneAndUpdate(ctx, v, filter, update, opts...); err != nil {
return err
}
return mm.DelCache(ctx, key)
}
// FindOneAndUpdateNoCache updates the document with given filter with update, and unmarshals it into v.
func (mm *Model) FindOneAndUpdateNoCache(ctx context.Context, v, filter interface{},
update interface{}, opts ...*mopt.FindOneAndUpdateOptions) error {
return mm.Model.FindOneAndUpdate(ctx, v, filter, update, opts...)
}
// GetCache unmarshal the cache into v with given key.
func (mm *Model) GetCache(key string, v interface{}) error {
return mm.cache.Get(key, v)
}
// InsertOne inserts a single document into the collection, and remove the cache placeholder.
func (mm *Model) InsertOne(ctx context.Context, key string, document interface{},
opts ...*mopt.InsertOneOptions) (*mongo.InsertOneResult, error) {
res, err := mm.Model.InsertOne(ctx, document, opts...)
if err != nil {
return nil, err
}
if err = mm.DelCache(ctx, key); err != nil {
return nil, err
}
return res, nil
}
// InsertOneNoCache inserts a single document into the collection.
func (mm *Model) InsertOneNoCache(ctx context.Context, document interface{},
opts ...*mopt.InsertOneOptions) (*mongo.InsertOneResult, error) {
return mm.Model.InsertOne(ctx, document, opts...)
}
// ReplaceOne replaces a single document in the collection, and remove the cache.
func (mm *Model) ReplaceOne(ctx context.Context, key string, filter interface{}, replacement interface{},
opts ...*mopt.ReplaceOptions) (*mongo.UpdateResult, error) {
res, err := mm.Model.ReplaceOne(ctx, filter, replacement, opts...)
if err != nil {
return nil, err
}
if err = mm.DelCache(ctx, key); err != nil {
return nil, err
}
return res, nil
}
// ReplaceOneNoCache replaces a single document in the collection.
func (mm *Model) ReplaceOneNoCache(ctx context.Context, filter interface{}, replacement interface{},
opts ...*mopt.ReplaceOptions) (*mongo.UpdateResult, error) {
return mm.Model.ReplaceOne(ctx, filter, replacement, opts...)
}
// SetCache sets the cache with given key and value.
func (mm *Model) SetCache(key string, v interface{}) error {
return mm.cache.Set(key, v)
}
// UpdateByID updates the document with given id with update, and remove the cache.
func (mm *Model) UpdateByID(ctx context.Context, key string, id interface{}, update interface{},
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
res, err := mm.Model.UpdateByID(ctx, id, update, opts...)
if err != nil {
return nil, err
}
if err = mm.DelCache(ctx, key); err != nil {
return nil, err
}
return res, nil
}
// UpdateByIDNoCache updates the document with given id with update.
func (mm *Model) UpdateByIDNoCache(ctx context.Context, id interface{}, update interface{},
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
return mm.Model.UpdateByID(ctx, id, update, opts...)
}
// UpdateMany updates the documents that match filter with update, and remove the cache.
func (mm *Model) UpdateMany(ctx context.Context, keys []string, filter interface{}, update interface{},
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
res, err := mm.Model.UpdateMany(ctx, filter, update, opts...)
if err != nil {
return nil, err
}
if err = mm.DelCache(ctx, keys...); err != nil {
return nil, err
}
return res, nil
}
// UpdateManyNoCache updates the documents that match filter with update.
func (mm *Model) UpdateManyNoCache(ctx context.Context, filter interface{}, update interface{},
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
return mm.Model.UpdateMany(ctx, filter, update, opts...)
}
// UpdateOne updates the first document that matches filter with update, and remove the cache.
func (mm *Model) UpdateOne(ctx context.Context, key string, filter interface{}, update interface{},
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
res, err := mm.Model.UpdateOne(ctx, filter, update, opts...)
if err != nil {
return nil, err
}
if err = mm.DelCache(ctx, key); err != nil {
return nil, err
}
return res, nil
}
// UpdateOneNoCache updates the first document that matches filter with update.
func (mm *Model) UpdateOneNoCache(ctx context.Context, filter interface{}, update interface{},
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
return mm.Model.UpdateOne(ctx, filter, update, opts...)
}

View File

@@ -0,0 +1,581 @@
package monc
import (
"context"
"errors"
"sync/atomic"
"testing"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mon"
"github.com/zeromicro/go-zero/core/stores/redis"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
)
func TestNewModel(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
_, err := newModel("foo", mt.DB.Name(), mt.Coll.Name(), nil)
assert.NotNil(mt, err)
})
}
func TestModel_DelCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
assert.Nil(t, m.cache.Set("bar", "baz"))
assert.Nil(t, m.DelCache(context.Background(), "foo", "bar"))
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
assert.True(t, m.cache.IsNotFound(m.cache.Get("bar", &v)))
})
}
func TestModel_DeleteOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
val, err := m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
_, err = m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
_, err = m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
assert.Equal(t, errMocked, err)
})
}
func TestModel_DeleteOneNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
val, err := m.DeleteOneNoCache(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
var v string
assert.Nil(t, m.cache.Get("foo", &v))
})
}
func TestModel_FindOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
resp := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "foo", Value: "bar"},
})
mt.AddMockResponses(resp)
m := createModel(t, mt)
var v struct {
Foo string `bson:"foo"`
}
assert.Nil(t, m.FindOne(context.Background(), "foo", &v, bson.D{}))
assert.Equal(t, "bar", v.Foo)
assert.Nil(t, m.cache.Set("foo", "bar"))
})
}
func TestModel_FindOneNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
resp := mtest.CreateCursorResponse(
1,
"DBName.CollectionName",
mtest.FirstBatch,
bson.D{
{Key: "foo", Value: "bar"},
})
mt.AddMockResponses(resp)
m := createModel(t, mt)
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneNoCache(context.Background(), &v, bson.D{}))
assert.Equal(t, "bar", v.Foo)
})
}
func TestModel_FindOneAndDelete(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
assert.Equal(t, "bar", v.Foo)
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
assert.NotNil(t, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
assert.Equal(t, errMocked, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
})
}
func TestModel_FindOneAndDeleteNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndDeleteNoCache(context.Background(), &v, bson.D{}))
assert.Equal(t, "bar", v.Foo)
})
}
func TestModel_FindOneAndReplace(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
assert.Equal(t, "bar", v.Foo)
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
assert.NotNil(t, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
assert.Equal(t, errMocked, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
})
}
func TestModel_FindOneAndReplaceNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndReplaceNoCache(context.Background(), &v, bson.D{}, bson.D{
{Key: "name", Value: "Mary"},
}))
assert.Equal(t, "bar", v.Foo)
})
}
func TestModel_FindOneAndUpdate(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
assert.Equal(t, "bar", v.Foo)
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
assert.NotNil(t, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
assert.Equal(t, errMocked, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
})
}
func TestModel_FindOneAndUpdateNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
v := struct {
Foo string `bson:"foo"`
}{}
assert.Nil(t, m.FindOneAndUpdateNoCache(context.Background(), &v, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
}))
assert.Equal(t, "bar", v.Foo)
})
}
func TestModel_GetCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
m := createModel(t, mt)
assert.NotNil(t, m.cache)
assert.Nil(t, m.cache.Set("foo", "bar"))
var s string
assert.Nil(t, m.cache.Get("foo", &s))
assert.Equal(t, "bar", s)
})
}
func TestModel_InsertOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
resp, err := m.InsertOne(context.Background(), "foo", bson.D{
{Key: "name", Value: "Mary"},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
_, err = m.InsertOne(context.Background(), "foo", bson.D{
{Key: "name", Value: "Mary"},
})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
_, err = m.InsertOne(context.Background(), "foo", bson.D{
{Key: "name", Value: "Mary"},
})
assert.Equal(t, errMocked, err)
})
}
func TestModel_InsertOneNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
resp, err := m.InsertOneNoCache(context.Background(), bson.D{
{Key: "name", Value: "Mary"},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
})
}
func TestModel_ReplaceOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
resp, err := m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "foo", Value: "baz"},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
_, err = m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "foo", Value: "baz"},
})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
_, err = m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "foo", Value: "baz"},
})
assert.Equal(t, errMocked, err)
})
}
func TestModel_ReplaceOneNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
resp, err := m.ReplaceOneNoCache(context.Background(), bson.D{}, bson.D{
{Key: "foo", Value: "baz"},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
})
}
func TestModel_SetCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
m := createModel(t, mt)
assert.Nil(t, m.SetCache("foo", "bar"))
var v string
assert.Nil(t, m.GetCache("foo", &v))
assert.Equal(t, "bar", v)
})
}
func TestModel_UpdateByID(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
resp, err := m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
_, err = m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
_, err = m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Equal(t, errMocked, err)
})
}
func TestModel_UpdateByIDNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
resp, err := m.UpdateByIDNoCache(context.Background(), bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
})
}
func TestModel_UpdateMany(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
assert.Nil(t, m.cache.Set("bar", "baz"))
resp, err := m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
assert.True(t, m.cache.IsNotFound(m.cache.Get("bar", &v)))
_, err = m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.NotNil(t, err)
m.cache = mockedCache{m.cache}
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
_, err = m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Equal(t, errMocked, err)
})
}
func TestModel_UpdateManyNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
resp, err := m.UpdateManyNoCache(context.Background(), bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
})
}
func TestModel_UpdateOne(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
assert.Nil(t, m.cache.Set("foo", "bar"))
resp, err := m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
var v string
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
_, err = m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.NotNil(t, err)
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m.cache = mockedCache{m.cache}
_, err = m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Equal(t, errMocked, err)
})
}
func TestModel_UpdateOneNoCache(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
defer mt.Close()
mt.Run("test", func(mt *mtest.T) {
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
}...))
m := createModel(t, mt)
resp, err := m.UpdateOneNoCache(context.Background(), bson.D{}, bson.D{
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
})
assert.Nil(t, err)
assert.NotNil(t, resp)
})
}
func createModel(t *testing.T, mt *mtest.T) *Model {
s, err := miniredis.Run()
assert.Nil(t, err)
mon.Inject(mt.Name(), mt.Client)
if atomic.AddInt32(&index, 1)%2 == 0 {
return MustNewNodeModel(mt.Name(), mt.DB.Name(), mt.Coll.Name(), redis.New(s.Addr()))
} else {
return MustNewModel(mt.Name(), mt.DB.Name(), mt.Coll.Name(), cache.CacheConf{
cache.NodeConf{
RedisConf: redis.RedisConf{
Host: s.Addr(),
Type: redis.NodeType,
},
Weight: 100,
},
})
}
}
var (
errMocked = errors.New("mocked error")
index int32
)
type mockedCache struct {
cache.Cache
}
func (m mockedCache) DelCtx(_ context.Context, _ ...string) error {
return errMocked
}

View File

@@ -12,8 +12,8 @@ var (
ErrNotFound = mgo.ErrNotFound ErrNotFound = mgo.ErrNotFound
// can't use one SingleFlight per conn, because multiple conns may share the same cache key. // can't use one SingleFlight per conn, because multiple conns may share the same cache key.
sharedCalls = syncx.NewSingleFlight() singleFlight = syncx.NewSingleFlight()
stats = cache.NewStat("mongoc") stats = cache.NewStat("mongoc")
) )
type ( type (

View File

@@ -34,7 +34,7 @@ func TestCollection_Count(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
cach := cache.NewNode(r, sharedCalls, stats, mgo.ErrNotFound) cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach) c := newCollection(dummyConn{}, cach)
val, err := c.Count("any") val, err := c.Count("any")
assert.Nil(t, err) assert.Nil(t, err)
@@ -98,7 +98,7 @@ func TestStat(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
cach := cache.NewNode(r, sharedCalls, stats, mgo.ErrNotFound) cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach).(*cachedCollection) c := newCollection(dummyConn{}, cach).(*cachedCollection)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@@ -121,7 +121,7 @@ func TestStatCacheFails(t *testing.T) {
defer log.SetOutput(os.Stdout) defer log.SetOutput(os.Stdout)
r := redis.New("localhost:59999") r := redis.New("localhost:59999")
cach := cache.NewNode(r, sharedCalls, stats, mgo.ErrNotFound) cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach) c := newCollection(dummyConn{}, cach)
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
@@ -142,7 +142,7 @@ func TestStatDbFails(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
cach := cache.NewNode(r, sharedCalls, stats, mgo.ErrNotFound) cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach).(*cachedCollection) c := newCollection(dummyConn{}, cach).(*cachedCollection)
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
@@ -164,7 +164,7 @@ func TestStatFromMemory(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()
cach := cache.NewNode(r, sharedCalls, stats, mgo.ErrNotFound) cach := cache.NewNode(r, singleFlight, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach).(*cachedCollection) c := newCollection(dummyConn{}, cach).(*cachedCollection)
var all sync.WaitGroup var all sync.WaitGroup

View File

@@ -38,7 +38,7 @@ func MustNewModel(url, collection string, c cache.CacheConf, opts ...cache.Optio
// NewModel returns a Model with a cache cluster. // NewModel returns a Model with a cache cluster.
func NewModel(url, collection string, conf cache.CacheConf, opts ...cache.Option) (*Model, error) { func NewModel(url, collection string, conf cache.CacheConf, opts ...cache.Option) (*Model, error) {
c := cache.New(conf, sharedCalls, stats, mgo.ErrNotFound, opts...) c := cache.New(conf, singleFlight, stats, mgo.ErrNotFound, opts...)
return NewModelWithCache(url, collection, c) return NewModelWithCache(url, collection, c)
} }
@@ -51,7 +51,7 @@ func NewModelWithCache(url, collection string, c cache.Cache) (*Model, error) {
// NewNodeModel returns a Model with a cache node. // NewNodeModel returns a Model with a cache node.
func NewNodeModel(url, collection string, rds *redis.Redis, opts ...cache.Option) (*Model, error) { func NewNodeModel(url, collection string, rds *redis.Redis, opts ...cache.Option) (*Model, error) {
c := cache.NewNode(rds, sharedCalls, stats, mgo.ErrNotFound, opts...) c := cache.NewNode(rds, singleFlight, stats, mgo.ErrNotFound, opts...)
return NewModelWithCache(url, collection, c) return NewModelWithCache(url, collection, c)
} }

104
core/stores/redis/hook.go Normal file
View File

@@ -0,0 +1,104 @@
package redis
import (
"context"
"strings"
"time"
red "github.com/go-redis/redis/v8"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/core/timex"
"github.com/zeromicro/go-zero/core/trace"
"go.opentelemetry.io/otel"
tracestd "go.opentelemetry.io/otel/trace"
)
// spanName is the span name of the redis calls.
const spanName = "redis"
var (
startTimeKey = contextKey("startTime")
durationHook = hook{tracer: otel.GetTracerProvider().Tracer(trace.TraceName)}
)
type (
contextKey string
hook struct {
tracer tracestd.Tracer
}
)
func (h hook) BeforeProcess(ctx context.Context, _ red.Cmder) (context.Context, error) {
return h.startSpan(context.WithValue(ctx, startTimeKey, timex.Now())), nil
}
func (h hook) AfterProcess(ctx context.Context, cmd red.Cmder) error {
h.endSpan(ctx)
val := ctx.Value(startTimeKey)
if val == nil {
return nil
}
start, ok := val.(time.Duration)
if !ok {
return nil
}
duration := timex.Since(start)
if duration > slowThreshold.Load() {
logDuration(ctx, cmd, duration)
}
return nil
}
func (h hook) BeforeProcessPipeline(ctx context.Context, _ []red.Cmder) (context.Context, error) {
return h.startSpan(context.WithValue(ctx, startTimeKey, timex.Now())), nil
}
func (h hook) AfterProcessPipeline(ctx context.Context, cmds []red.Cmder) error {
h.endSpan(ctx)
if len(cmds) == 0 {
return nil
}
val := ctx.Value(startTimeKey)
if val == nil {
return nil
}
start, ok := val.(time.Duration)
if !ok {
return nil
}
duration := timex.Since(start)
if duration > slowThreshold.Load()*time.Duration(len(cmds)) {
logDuration(ctx, cmds[0], duration)
}
return nil
}
func logDuration(ctx context.Context, cmd red.Cmder, duration time.Duration) {
var buf strings.Builder
for i, arg := range cmd.Args() {
if i > 0 {
buf.WriteByte(' ')
}
buf.WriteString(mapping.Repr(arg))
}
logx.WithContext(ctx).WithDuration(duration).Slowf("[REDIS] slowcall on executing: %s", buf.String())
}
func (h hook) startSpan(ctx context.Context) context.Context {
ctx, _ = h.tracer.Start(ctx, spanName)
return ctx
}
func (h hook) endSpan(ctx context.Context) {
tracestd.SpanFromContext(ctx).End()
}

View File

@@ -0,0 +1,168 @@
package redis
import (
"context"
"log"
"strings"
"testing"
"time"
red "github.com/go-redis/redis/v8"
"github.com/stretchr/testify/assert"
ztrace "github.com/zeromicro/go-zero/core/trace"
tracesdk "go.opentelemetry.io/otel/trace"
)
func TestHookProcessCase1(t *testing.T) {
ztrace.StartAgent(ztrace.Config{
Name: "go-zero-test",
Endpoint: "http://localhost:14268/api/traces",
Batcher: "jaeger",
Sampler: 1.0,
})
writer := log.Writer()
var buf strings.Builder
log.SetOutput(&buf)
defer log.SetOutput(writer)
ctx, err := durationHook.BeforeProcess(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
assert.Nil(t, durationHook.AfterProcess(ctx, red.NewCmd(context.Background())))
assert.False(t, strings.Contains(buf.String(), "slow"))
assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
}
func TestHookProcessCase2(t *testing.T) {
ztrace.StartAgent(ztrace.Config{
Name: "go-zero-test",
Endpoint: "http://localhost:14268/api/traces",
Batcher: "jaeger",
Sampler: 1.0,
})
writer := log.Writer()
var buf strings.Builder
log.SetOutput(&buf)
defer log.SetOutput(writer)
ctx, err := durationHook.BeforeProcess(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
time.Sleep(slowThreshold.Load() + time.Millisecond)
assert.Nil(t, durationHook.AfterProcess(ctx, red.NewCmd(context.Background(), "foo", "bar")))
assert.True(t, strings.Contains(buf.String(), "slow"))
assert.True(t, strings.Contains(buf.String(), "trace"))
assert.True(t, strings.Contains(buf.String(), "span"))
}
func TestHookProcessCase3(t *testing.T) {
writer := log.Writer()
var buf strings.Builder
log.SetOutput(&buf)
defer log.SetOutput(writer)
assert.Nil(t, durationHook.AfterProcess(context.Background(), red.NewCmd(context.Background())))
assert.True(t, buf.Len() == 0)
}
func TestHookProcessCase4(t *testing.T) {
writer := log.Writer()
var buf strings.Builder
log.SetOutput(&buf)
defer log.SetOutput(writer)
ctx := context.WithValue(context.Background(), startTimeKey, "foo")
assert.Nil(t, durationHook.AfterProcess(ctx, red.NewCmd(context.Background())))
assert.True(t, buf.Len() == 0)
}
func TestHookProcessPipelineCase1(t *testing.T) {
writer := log.Writer()
var buf strings.Builder
log.SetOutput(&buf)
defer log.SetOutput(writer)
ctx, err := durationHook.BeforeProcessPipeline(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
red.NewCmd(context.Background()),
}))
assert.False(t, strings.Contains(buf.String(), "slow"))
}
func TestHookProcessPipelineCase2(t *testing.T) {
ztrace.StartAgent(ztrace.Config{
Name: "go-zero-test",
Endpoint: "http://localhost:14268/api/traces",
Batcher: "jaeger",
Sampler: 1.0,
})
writer := log.Writer()
var buf strings.Builder
log.SetOutput(&buf)
defer log.SetOutput(writer)
ctx, err := durationHook.BeforeProcessPipeline(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, "redis", tracesdk.SpanFromContext(ctx).(interface{ Name() string }).Name())
time.Sleep(slowThreshold.Load() + time.Millisecond)
assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
red.NewCmd(context.Background(), "foo", "bar"),
}))
assert.True(t, strings.Contains(buf.String(), "slow"))
assert.True(t, strings.Contains(buf.String(), "trace"))
assert.True(t, strings.Contains(buf.String(), "span"))
}
func TestHookProcessPipelineCase3(t *testing.T) {
writer := log.Writer()
var buf strings.Builder
log.SetOutput(&buf)
defer log.SetOutput(writer)
assert.Nil(t, durationHook.AfterProcessPipeline(context.Background(), []red.Cmder{
red.NewCmd(context.Background()),
}))
assert.True(t, buf.Len() == 0)
}
func TestHookProcessPipelineCase4(t *testing.T) {
writer := log.Writer()
var buf strings.Builder
log.SetOutput(&buf)
defer log.SetOutput(writer)
ctx := context.WithValue(context.Background(), startTimeKey, "foo")
assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
red.NewCmd(context.Background()),
}))
assert.True(t, buf.Len() == 0)
}
func TestHookProcessPipelineCase5(t *testing.T) {
writer := log.Writer()
var buf strings.Builder
log.SetOutput(&buf)
defer log.SetOutput(writer)
ctx := context.WithValue(context.Background(), startTimeKey, "foo")
assert.Nil(t, durationHook.AfterProcessPipeline(ctx, nil))
assert.True(t, buf.Len() == 0)
}

View File

@@ -1,32 +0,0 @@
package redis
import (
"strings"
red "github.com/go-redis/redis"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/core/timex"
)
func checkDuration(proc func(red.Cmder) error) func(red.Cmder) error {
return func(cmd red.Cmder) error {
start := timex.Now()
defer func() {
duration := timex.Since(start)
if duration > slowThreshold.Load() {
var buf strings.Builder
for i, arg := range cmd.Args() {
if i > 0 {
buf.WriteByte(' ')
}
buf.WriteString(mapping.Repr(arg))
}
logx.WithDuration(duration).Slowf("[REDIS] slowcall on executing: %s", buf.String())
}
}()
return proc(cmd)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,7 @@
package redis package redis
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"io" "io"
@@ -9,8 +10,9 @@ import (
"time" "time"
"github.com/alicebob/miniredis/v2" "github.com/alicebob/miniredis/v2"
red "github.com/go-redis/redis" red "github.com/go-redis/redis/v8"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/core/stringx"
) )
@@ -361,6 +363,11 @@ func TestRedis_List(t *testing.T) {
vals, err = client.Lrange("key", 0, 10) vals, err = client.Lrange("key", 0, 10)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, []string{"value2", "value3", "value4"}, vals) assert.EqualValues(t, []string{"value2", "value3", "value4"}, vals)
err = client.Ltrim("key", 0, 1)
assert.Nil(t, err)
vals, err = client.Lrange("key", 0, 10)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value2", "value3"}, vals)
}) })
} }
@@ -380,30 +387,33 @@ func TestRedis_Mget(t *testing.T) {
func TestRedis_SetBit(t *testing.T) { func TestRedis_SetBit(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
err := New(client.Addr, badType()).SetBit("key", 1, 1) _, err := New(client.Addr, badType()).SetBit("key", 1, 1)
assert.NotNil(t, err) assert.NotNil(t, err)
err = client.SetBit("key", 1, 1) val, err := client.SetBit("key", 1, 1)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 0, val)
}) })
} }
func TestRedis_GetBit(t *testing.T) { func TestRedis_GetBit(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
err := client.SetBit("key", 2, 1) val, err := client.SetBit("key", 2, 1)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 0, val)
_, err = New(client.Addr, badType()).GetBit("key", 2) _, err = New(client.Addr, badType()).GetBit("key", 2)
assert.NotNil(t, err) assert.NotNil(t, err)
val, err := client.GetBit("key", 2) v, err := client.GetBit("key", 2)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 1, val) assert.Equal(t, 1, v)
}) })
} }
func TestRedis_BitCount(t *testing.T) { func TestRedis_BitCount(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
for i := 0; i < 11; i++ { for i := 0; i < 11; i++ {
err := client.SetBit("key", int64(i), 1) val, err := client.SetBit("key", int64(i), 1)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 0, val)
} }
_, err := New(client.Addr, badType()).BitCount("key", 0, -1) _, err := New(client.Addr, badType()).BitCount("key", 0, -1)
@@ -694,6 +704,28 @@ func TestRedis_Set(t *testing.T) {
}) })
} }
func TestRedis_GetSet(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := New(client.Addr, badType()).GetSet("hello", "world")
assert.NotNil(t, err)
val, err := client.GetSet("hello", "world")
assert.Nil(t, err)
assert.Equal(t, "", val)
val, err = client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "world", val)
val, err = client.GetSet("hello", "newworld")
assert.Nil(t, err)
assert.Equal(t, "world", val)
val, err = client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "newworld", val)
ret, err := client.Del("hello")
assert.Nil(t, err)
assert.Equal(t, 1, ret)
})
}
func TestRedis_SetGetDel(t *testing.T) { func TestRedis_SetGetDel(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
err := New(client.Addr, badType()).Set("hello", "world") err := New(client.Addr, badType()).Set("hello", "world")
@@ -958,13 +990,14 @@ func TestRedis_SortedSet(t *testing.T) {
assert.NotNil(t, err) assert.NotNil(t, err)
client.Zadd("second", 2, "aa") client.Zadd("second", 2, "aa")
client.Zadd("third", 3, "bbb") client.Zadd("third", 3, "bbb")
val, err = client.Zunionstore("union", ZStore{ val, err = client.Zunionstore("union", &ZStore{
Keys: []string{"second", "third"},
Weights: []float64{1, 2}, Weights: []float64{1, 2},
Aggregate: "SUM", Aggregate: "SUM",
}, "second", "third") })
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(2), val) assert.Equal(t, int64(2), val)
_, err = New(client.Addr, badType()).Zunionstore("union", ZStore{}) _, err = New(client.Addr, badType()).Zunionstore("union", &ZStore{})
assert.NotNil(t, err) assert.NotNil(t, err)
vals, err = client.Zrange("union", 0, 10000) vals, err = client.Zrange("union", 0, 10000)
assert.Nil(t, err) assert.Nil(t, err)
@@ -982,9 +1015,9 @@ func TestRedis_Pipelined(t *testing.T) {
})) }))
err := client.Pipelined( err := client.Pipelined(
func(pipe Pipeliner) error { func(pipe Pipeliner) error {
pipe.Incr("pipelined_counter") pipe.Incr(context.Background(), "pipelined_counter")
pipe.Expire("pipelined_counter", time.Hour) pipe.Expire(context.Background(), "pipelined_counter", time.Hour)
pipe.ZAdd("zadd", Z{Score: 12, Member: "zadd"}) pipe.ZAdd(context.Background(), "zadd", &Z{Score: 12, Member: "zadd"})
return nil return nil
}, },
) )
@@ -1130,6 +1163,8 @@ func TestRedis_WithPass(t *testing.T) {
} }
func runOnRedis(t *testing.T, fn func(client *Redis)) { func runOnRedis(t *testing.T, fn func(client *Redis)) {
logx.Disable()
s, err := miniredis.Run() s, err := miniredis.Run()
assert.Nil(t, err) assert.Nil(t, err)
defer func() { defer func() {
@@ -1148,6 +1183,8 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) {
} }
func runOnRedisTLS(t *testing.T, fn func(client *Redis)) { func runOnRedisTLS(t *testing.T, fn func(client *Redis)) {
logx.Disable()
s, err := miniredis.RunTLS(&tls.Config{ s, err := miniredis.RunTLS(&tls.Config{
Certificates: make([]tls.Certificate, 1), Certificates: make([]tls.Certificate, 1),
InsecureSkipVerify: true, InsecureSkipVerify: true,
@@ -1177,6 +1214,6 @@ type mockedNode struct {
RedisNode RedisNode
} }
func (n mockedNode) BLPop(timeout time.Duration, keys ...string) *red.StringSliceCmd { func (n mockedNode) BLPop(ctx context.Context, timeout time.Duration, keys ...string) *red.StringSliceCmd {
return red.NewStringSliceCmd("foo", "bar") return red.NewStringSliceCmd(context.Background(), "foo", "bar")
} }

View File

@@ -3,7 +3,7 @@ package redis
import ( import (
"fmt" "fmt"
red "github.com/go-redis/redis" red "github.com/go-redis/redis/v8"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )

View File

@@ -4,7 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"io" "io"
red "github.com/go-redis/redis" red "github.com/go-redis/redis/v8"
"github.com/zeromicro/go-zero/core/syncx" "github.com/zeromicro/go-zero/core/syncx"
) )
@@ -32,7 +32,8 @@ func getClient(r *Redis) (*red.Client, error) {
MinIdleConns: idleConns, MinIdleConns: idleConns,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
}) })
store.WrapProcess(checkDuration) store.AddHook(durationHook)
return store, nil return store, nil
}) })
if err != nil { if err != nil {

View File

@@ -4,7 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"io" "io"
red "github.com/go-redis/redis" red "github.com/go-redis/redis/v8"
"github.com/zeromicro/go-zero/core/syncx" "github.com/zeromicro/go-zero/core/syncx"
) )
@@ -25,7 +25,7 @@ func getCluster(r *Redis) (*red.ClusterClient, error) {
MinIdleConns: idleConns, MinIdleConns: idleConns,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
}) })
store.WrapProcess(checkDuration) store.AddHook(durationHook)
return store, nil return store, nil
}) })

View File

@@ -2,28 +2,36 @@ package redis
import ( import (
"math/rand" "math/rand"
"strconv"
"sync/atomic" "sync/atomic"
"time" "time"
red "github.com/go-redis/redis" red "github.com/go-redis/redis/v8"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/core/stringx"
) )
const ( const (
randomLen = 16
tolerance = 500 // milliseconds
millisPerSecond = 1000
lockCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then
redis.call("SET", KEYS[1], ARGV[1], "PX", ARGV[2])
return "OK"
else
return redis.call("SET", KEYS[1], ARGV[1], "NX", "PX", ARGV[2])
end`
delCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then delCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1]) return redis.call("DEL", KEYS[1])
else else
return 0 return 0
end` end`
randomLen = 16
) )
// A RedisLock is a redis lock. // A RedisLock is a redis lock.
type RedisLock struct { type RedisLock struct {
store *Redis store *Redis
seconds uint32 seconds uint32
count int32
key string key string
id string id string
} }
@@ -43,35 +51,30 @@ func NewRedisLock(store *Redis, key string) *RedisLock {
// Acquire acquires the lock. // Acquire acquires the lock.
func (rl *RedisLock) Acquire() (bool, error) { func (rl *RedisLock) Acquire() (bool, error) {
newCount := atomic.AddInt32(&rl.count, 1) seconds := atomic.LoadUint32(&rl.seconds)
if newCount > 1 { resp, err := rl.store.Eval(lockCommand, []string{rl.key}, []string{
rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance),
})
if err == red.Nil {
return false, nil
} else if err != nil {
logx.Errorf("Error on acquiring lock for %s, %s", rl.key, err.Error())
return false, err
} else if resp == nil {
return false, nil
}
reply, ok := resp.(string)
if ok && reply == "OK" {
return true, nil return true, nil
} }
seconds := atomic.LoadUint32(&rl.seconds) logx.Errorf("Unknown reply when acquiring lock for %s: %v", rl.key, resp)
ok, err := rl.store.SetnxEx(rl.key, rl.id, int(seconds+1)) // +1s for tolerance return false, nil
if err == red.Nil {
atomic.AddInt32(&rl.count, -1)
return false, nil
} else if err != nil {
atomic.AddInt32(&rl.count, -1)
logx.Errorf("Error on acquiring lock for %s, %s", rl.key, err.Error())
return false, err
} else if !ok {
atomic.AddInt32(&rl.count, -1)
return false, nil
}
return true, nil
} }
// Release releases the lock. // Release releases the lock.
func (rl *RedisLock) Release() (bool, error) { func (rl *RedisLock) Release() (bool, error) {
newCount := atomic.AddInt32(&rl.count, -1)
if newCount > 0 {
return true, nil
}
resp, err := rl.store.Eval(delCommand, []string{rl.key}, []string{rl.id}) resp, err := rl.store.Eval(delCommand, []string{rl.key}, []string{rl.id})
if err != nil { if err != nil {
return false, err return false, err

View File

@@ -29,25 +29,5 @@ func TestRedisLock(t *testing.T) {
endAcquire, err := secondLock.Acquire() endAcquire, err := secondLock.Acquire()
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, endAcquire) assert.True(t, endAcquire)
endAcquire, err = secondLock.Acquire()
assert.Nil(t, err)
assert.True(t, endAcquire)
release, err = secondLock.Release()
assert.Nil(t, err)
assert.True(t, release)
againAcquire, err = firstLock.Acquire()
assert.Nil(t, err)
assert.False(t, againAcquire)
release, err = secondLock.Release()
assert.Nil(t, err)
assert.True(t, release)
firstAcquire, err = firstLock.Acquire()
assert.Nil(t, err)
assert.True(t, firstAcquire)
}) })
} }

View File

@@ -17,10 +17,12 @@ func CreateRedis() (r *redis.Redis, clean func(), err error) {
return redis.New(mr.Addr()), func() { return redis.New(mr.Addr()), func() {
ch := make(chan lang.PlaceholderType) ch := make(chan lang.PlaceholderType)
go func() { go func() {
mr.Close() mr.Close()
close(ch) close(ch)
}() }()
select { select {
case <-ch: case <-ch:
case <-time.After(time.Second): case <-time.After(time.Second):

View File

@@ -4,9 +4,12 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
) )
func TestScriptCache(t *testing.T) { func TestScriptCache(t *testing.T) {
logx.Disable()
cache := GetScriptCache() cache := GetScriptCache()
cache.SetSha("foo", "bar") cache.SetSha("foo", "bar")
cache.SetSha("bla", "blabla") cache.SetSha("bla", "blabla")

View File

@@ -1,6 +1,7 @@
package sqlc package sqlc
import ( import (
"context"
"database/sql" "database/sql"
"time" "time"
@@ -18,19 +19,27 @@ var (
ErrNotFound = sqlx.ErrNotFound ErrNotFound = sqlx.ErrNotFound
// can't use one SingleFlight per conn, because multiple conns may share the same cache key. // can't use one SingleFlight per conn, because multiple conns may share the same cache key.
exclusiveCalls = syncx.NewSingleFlight() singleFlights = syncx.NewSingleFlight()
stats = cache.NewStat("sqlc") stats = cache.NewStat("sqlc")
) )
type ( type (
// ExecFn defines the sql exec method. // ExecFn defines the sql exec method.
ExecFn func(conn sqlx.SqlConn) (sql.Result, error) ExecFn func(conn sqlx.SqlConn) (sql.Result, error)
// ExecCtxFn defines the sql exec method.
ExecCtxFn func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error)
// IndexQueryFn defines the query method that based on unique indexes. // IndexQueryFn defines the query method that based on unique indexes.
IndexQueryFn func(conn sqlx.SqlConn, v interface{}) (interface{}, error) IndexQueryFn func(conn sqlx.SqlConn, v interface{}) (interface{}, error)
// IndexQueryCtxFn defines the query method that based on unique indexes.
IndexQueryCtxFn func(ctx context.Context, conn sqlx.SqlConn, v interface{}) (interface{}, error)
// PrimaryQueryFn defines the query method that based on primary keys. // PrimaryQueryFn defines the query method that based on primary keys.
PrimaryQueryFn func(conn sqlx.SqlConn, v, primary interface{}) error PrimaryQueryFn func(conn sqlx.SqlConn, v, primary interface{}) error
// PrimaryQueryCtxFn defines the query method that based on primary keys.
PrimaryQueryCtxFn func(ctx context.Context, conn sqlx.SqlConn, v, primary interface{}) error
// QueryFn defines the query method. // QueryFn defines the query method.
QueryFn func(conn sqlx.SqlConn, v interface{}) error QueryFn func(conn sqlx.SqlConn, v interface{}) error
// QueryCtxFn defines the query method.
QueryCtxFn func(ctx context.Context, conn sqlx.SqlConn, v interface{}) error
// A CachedConn is a DB connection with cache capability. // A CachedConn is a DB connection with cache capability.
CachedConn struct { CachedConn struct {
@@ -41,7 +50,7 @@ type (
// NewConn returns a CachedConn with a redis cluster cache. // NewConn returns a CachedConn with a redis cluster cache.
func NewConn(db sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) CachedConn { func NewConn(db sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) CachedConn {
cc := cache.New(c, exclusiveCalls, stats, sql.ErrNoRows, opts...) cc := cache.New(c, singleFlights, stats, sql.ErrNoRows, opts...)
return NewConnWithCache(db, cc) return NewConnWithCache(db, cc)
} }
@@ -55,28 +64,46 @@ func NewConnWithCache(db sqlx.SqlConn, c cache.Cache) CachedConn {
// NewNodeConn returns a CachedConn with a redis node cache. // NewNodeConn returns a CachedConn with a redis node cache.
func NewNodeConn(db sqlx.SqlConn, rds *redis.Redis, opts ...cache.Option) CachedConn { func NewNodeConn(db sqlx.SqlConn, rds *redis.Redis, opts ...cache.Option) CachedConn {
c := cache.NewNode(rds, exclusiveCalls, stats, sql.ErrNoRows, opts...) c := cache.NewNode(rds, singleFlights, stats, sql.ErrNoRows, opts...)
return NewConnWithCache(db, c) return NewConnWithCache(db, c)
} }
// DelCache deletes cache with keys. // DelCache deletes cache with keys.
func (cc CachedConn) DelCache(keys ...string) error { func (cc CachedConn) DelCache(keys ...string) error {
return cc.cache.Del(keys...) return cc.DelCacheCtx(context.Background(), keys...)
}
// DelCacheCtx deletes cache with keys.
func (cc CachedConn) DelCacheCtx(ctx context.Context, keys ...string) error {
return cc.cache.DelCtx(ctx, keys...)
} }
// GetCache unmarshals cache with given key into v. // GetCache unmarshals cache with given key into v.
func (cc CachedConn) GetCache(key string, v interface{}) error { func (cc CachedConn) GetCache(key string, v interface{}) error {
return cc.cache.Get(key, v) return cc.GetCacheCtx(context.Background(), key, v)
}
// GetCacheCtx unmarshals cache with given key into v.
func (cc CachedConn) GetCacheCtx(ctx context.Context, key string, v interface{}) error {
return cc.cache.GetCtx(ctx, key, v)
} }
// Exec runs given exec on given keys, and returns execution result. // Exec runs given exec on given keys, and returns execution result.
func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) { func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) {
res, err := exec(cc.db) execCtx := func(_ context.Context, conn sqlx.SqlConn) (sql.Result, error) {
return exec(conn)
}
return cc.ExecCtx(context.Background(), execCtx, keys...)
}
// ExecCtx runs given exec on given keys, and returns execution result.
func (cc CachedConn) ExecCtx(ctx context.Context, exec ExecCtxFn, keys ...string) (sql.Result, error) {
res, err := exec(ctx, cc.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := cc.DelCache(keys...); err != nil { if err := cc.DelCacheCtx(ctx, keys...); err != nil {
return nil, err return nil, err
} }
@@ -85,31 +112,61 @@ func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) {
// ExecNoCache runs exec with given sql statement, without affecting cache. // ExecNoCache runs exec with given sql statement, without affecting cache.
func (cc CachedConn) ExecNoCache(q string, args ...interface{}) (sql.Result, error) { func (cc CachedConn) ExecNoCache(q string, args ...interface{}) (sql.Result, error) {
return cc.db.Exec(q, args...) return cc.ExecNoCacheCtx(context.Background(), q, args...)
}
// ExecNoCacheCtx runs exec with given sql statement, without affecting cache.
func (cc CachedConn) ExecNoCacheCtx(ctx context.Context, q string, args ...interface{}) (
sql.Result, error) {
return cc.db.ExecCtx(ctx, q, args...)
} }
// QueryRow unmarshals into v with given key and query func. // QueryRow unmarshals into v with given key and query func.
func (cc CachedConn) QueryRow(v interface{}, key string, query QueryFn) error { func (cc CachedConn) QueryRow(v interface{}, key string, query QueryFn) error {
return cc.cache.Take(v, key, func(v interface{}) error { queryCtx := func(_ context.Context, conn sqlx.SqlConn, v interface{}) error {
return query(cc.db, v) return query(conn, v)
}
return cc.QueryRowCtx(context.Background(), v, key, queryCtx)
}
// QueryRowCtx unmarshals into v with given key and query func.
func (cc CachedConn) QueryRowCtx(ctx context.Context, v interface{}, key string, query QueryCtxFn) error {
return cc.cache.TakeCtx(ctx, v, key, func(v interface{}) error {
return query(ctx, cc.db, v)
}) })
} }
// QueryRowIndex unmarshals into v with given key. // QueryRowIndex unmarshals into v with given key.
func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary interface{}) string, func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary interface{}) string,
indexQuery IndexQueryFn, primaryQuery PrimaryQueryFn) error { indexQuery IndexQueryFn, primaryQuery PrimaryQueryFn) error {
indexQueryCtx := func(_ context.Context, conn sqlx.SqlConn, v interface{}) (interface{}, error) {
return indexQuery(conn, v)
}
primaryQueryCtx := func(_ context.Context, conn sqlx.SqlConn, v, primary interface{}) error {
return primaryQuery(conn, v, primary)
}
return cc.QueryRowIndexCtx(context.Background(), v, key, keyer, indexQueryCtx, primaryQueryCtx)
}
// QueryRowIndexCtx unmarshals into v with given key.
func (cc CachedConn) QueryRowIndexCtx(ctx context.Context, v interface{}, key string,
keyer func(primary interface{}) string, indexQuery IndexQueryCtxFn,
primaryQuery PrimaryQueryCtxFn) error {
var primaryKey interface{} var primaryKey interface{}
var found bool var found bool
if err := cc.cache.TakeWithExpire(&primaryKey, key, func(val interface{}, expire time.Duration) (err error) { if err := cc.cache.TakeWithExpireCtx(ctx, &primaryKey, key,
primaryKey, err = indexQuery(cc.db, v) func(val interface{}, expire time.Duration) (err error) {
if err != nil { primaryKey, err = indexQuery(ctx, cc.db, v)
return if err != nil {
} return
}
found = true found = true
return cc.cache.SetWithExpire(keyer(primaryKey), v, expire+cacheSafeGapBetweenIndexAndPrimary) return cc.cache.SetWithExpireCtx(ctx, keyer(primaryKey), v,
}); err != nil { expire+cacheSafeGapBetweenIndexAndPrimary)
}); err != nil {
return err return err
} }
@@ -117,28 +174,54 @@ func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary
return nil return nil
} }
return cc.cache.Take(v, keyer(primaryKey), func(v interface{}) error { return cc.cache.TakeCtx(ctx, v, keyer(primaryKey), func(v interface{}) error {
return primaryQuery(cc.db, v, primaryKey) return primaryQuery(ctx, cc.db, v, primaryKey)
}) })
} }
// QueryRowNoCache unmarshals into v with given statement. // QueryRowNoCache unmarshals into v with given statement.
func (cc CachedConn) QueryRowNoCache(v interface{}, q string, args ...interface{}) error { func (cc CachedConn) QueryRowNoCache(v interface{}, q string, args ...interface{}) error {
return cc.db.QueryRow(v, q, args...) return cc.QueryRowNoCacheCtx(context.Background(), v, q, args...)
}
// QueryRowNoCacheCtx unmarshals into v with given statement.
func (cc CachedConn) QueryRowNoCacheCtx(ctx context.Context, v interface{}, q string,
args ...interface{}) error {
return cc.db.QueryRowCtx(ctx, v, q, args...)
} }
// QueryRowsNoCache unmarshals into v with given statement. // QueryRowsNoCache unmarshals into v with given statement.
// It doesn't use cache, because it might cause consistency problem. // It doesn't use cache, because it might cause consistency problem.
func (cc CachedConn) QueryRowsNoCache(v interface{}, q string, args ...interface{}) error { func (cc CachedConn) QueryRowsNoCache(v interface{}, q string, args ...interface{}) error {
return cc.db.QueryRows(v, q, args...) return cc.QueryRowsNoCacheCtx(context.Background(), v, q, args...)
}
// QueryRowsNoCacheCtx unmarshals into v with given statement.
// It doesn't use cache, because it might cause consistency problem.
func (cc CachedConn) QueryRowsNoCacheCtx(ctx context.Context, v interface{}, q string,
args ...interface{}) error {
return cc.db.QueryRowsCtx(ctx, v, q, args...)
} }
// SetCache sets v into cache with given key. // SetCache sets v into cache with given key.
func (cc CachedConn) SetCache(key string, v interface{}) error { func (cc CachedConn) SetCache(key string, val interface{}) error {
return cc.cache.Set(key, v) return cc.SetCacheCtx(context.Background(), key, val)
}
// SetCacheCtx sets v into cache with given key.
func (cc CachedConn) SetCacheCtx(ctx context.Context, key string, val interface{}) error {
return cc.cache.SetCtx(ctx, key, val)
} }
// Transact runs given fn in transaction mode. // Transact runs given fn in transaction mode.
func (cc CachedConn) Transact(fn func(sqlx.Session) error) error { func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
return cc.db.Transact(fn) fnCtx := func(_ context.Context, session sqlx.Session) error {
return fn(session)
}
return cc.TransactCtx(context.Background(), fnCtx)
}
// TransactCtx runs given fn in transaction mode.
func (cc CachedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
return cc.db.TransactCtx(ctx, fn)
} }

View File

@@ -1,6 +1,7 @@
package sqlc package sqlc
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
@@ -568,7 +569,7 @@ func TestNewConnWithCache(t *testing.T) {
defer clean() defer clean()
var conn trackedConn var conn trackedConn
c := NewConnWithCache(&conn, cache.NewNode(r, exclusiveCalls, stats, sql.ErrNoRows)) c := NewConnWithCache(&conn, cache.NewNode(r, singleFlights, stats, sql.ErrNoRows))
_, err = c.ExecNoCache("delete from user_table where id='kevin'") _, err = c.ExecNoCache("delete from user_table where id='kevin'")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, conn.execValue) assert.True(t, conn.execValue)
@@ -585,6 +586,30 @@ type dummySqlConn struct {
queryRow func(interface{}, string, ...interface{}) error queryRow func(interface{}, string, ...interface{}) error
} }
func (d dummySqlConn) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return nil, nil
}
func (d dummySqlConn) PrepareCtx(ctx context.Context, query string) (sqlx.StmtSession, error) {
return nil, nil
}
func (d dummySqlConn) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
return nil
}
func (d dummySqlConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
return nil
}
func (d dummySqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
return nil
}
func (d dummySqlConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
return nil
}
func (d dummySqlConn) Exec(query string, args ...interface{}) (sql.Result, error) { func (d dummySqlConn) Exec(query string, args ...interface{}) (sql.Result, error) {
return nil, nil return nil, nil
} }
@@ -594,6 +619,10 @@ func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) {
} }
func (d dummySqlConn) QueryRow(v interface{}, query string, args ...interface{}) error { func (d dummySqlConn) QueryRow(v interface{}, query string, args ...interface{}) error {
return d.QueryRowCtx(context.Background(), v, query, args...)
}
func (d dummySqlConn) QueryRowCtx(_ context.Context, v interface{}, query string, args ...interface{}) error {
if d.queryRow != nil { if d.queryRow != nil {
return d.queryRow(v, query, args...) return d.queryRow(v, query, args...)
} }
@@ -628,13 +657,21 @@ type trackedConn struct {
} }
func (c *trackedConn) Exec(query string, args ...interface{}) (sql.Result, error) { func (c *trackedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
return c.ExecCtx(context.Background(), query, args...)
}
func (c *trackedConn) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
c.execValue = true c.execValue = true
return c.dummySqlConn.Exec(query, args...) return c.dummySqlConn.ExecCtx(ctx, query, args...)
} }
func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{}) error { func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
return c.QueryRowsCtx(context.Background(), v, query, args...)
}
func (c *trackedConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
c.queryRowsValue = true c.queryRowsValue = true
return c.dummySqlConn.QueryRows(v, query, args...) return c.dummySqlConn.QueryRowsCtx(ctx, v, query, args...)
} }
func (c *trackedConn) RawDB() (*sql.DB, error) { func (c *trackedConn) RawDB() (*sql.DB, error) {
@@ -642,6 +679,12 @@ func (c *trackedConn) RawDB() (*sql.DB, error) {
} }
func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error { func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error {
c.transactValue = true return c.TransactCtx(context.Background(), func(_ context.Context, session sqlx.Session) error {
return c.dummySqlConn.Transact(fn) return fn(session)
})
}
func (c *trackedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
c.transactValue = true
return c.dummySqlConn.TransactCtx(ctx, fn)
} }

View File

@@ -25,6 +25,7 @@ type (
// A BulkInserter is used to batch insert records. // A BulkInserter is used to batch insert records.
// Postgresql is not supported yet, because of the sql is formated with symbol `$`. // Postgresql is not supported yet, because of the sql is formated with symbol `$`.
// Oracle is not supported yet, because of the sql is formated with symbol `:`.
BulkInserter struct { BulkInserter struct {
executor *executors.PeriodicalExecutor executor *executors.PeriodicalExecutor
inserter *dbInserter inserter *dbInserter

View File

@@ -1,6 +1,7 @@
package sqlx package sqlx
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"strconv" "strconv"
@@ -17,12 +18,40 @@ type mockedConn struct {
execErr error execErr error
} }
func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) { func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...interface{}) (sql.Result, error) {
c.query = query c.query = query
c.args = args c.args = args
return nil, c.execErr return nil, c.execErr
} }
func (c *mockedConn) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
panic("implement me")
}
func (c *mockedConn) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
panic("implement me")
}
func (c *mockedConn) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
panic("implement me")
}
func (c *mockedConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
panic("implement me")
}
func (c *mockedConn) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
panic("implement me")
}
func (c *mockedConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
panic("should not called")
}
func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
return c.ExecCtx(context.Background(), query, args...)
}
func (c *mockedConn) Prepare(query string) (StmtSession, error) { func (c *mockedConn) Prepare(query string) (StmtSession, error) {
panic("should not called") panic("should not called")
} }

View File

@@ -1,6 +1,7 @@
package sqlx package sqlx
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"testing" "testing"
@@ -16,7 +17,7 @@ func TestUnmarshalRowBool(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value bool var value bool
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.True(t, value) assert.True(t, value)
@@ -29,7 +30,7 @@ func TestUnmarshalRowBoolNotSettable(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value bool var value bool
assert.NotNil(t, query(db, func(rows *sql.Rows) error { assert.NotNil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true) return unmarshalRow(value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
}) })
@@ -41,7 +42,7 @@ func TestUnmarshalRowInt(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int var value int
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, 2, value) assert.EqualValues(t, 2, value)
@@ -54,7 +55,7 @@ func TestUnmarshalRowInt8(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int8 var value int8
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, int8(3), value) assert.EqualValues(t, int8(3), value)
@@ -67,7 +68,7 @@ func TestUnmarshalRowInt16(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int16 var value int16
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.Equal(t, int16(4), value) assert.Equal(t, int16(4), value)
@@ -80,7 +81,7 @@ func TestUnmarshalRowInt32(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int32 var value int32
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.Equal(t, int32(5), value) assert.Equal(t, int32(5), value)
@@ -93,7 +94,7 @@ func TestUnmarshalRowInt64(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int64 var value int64
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, int64(6), value) assert.EqualValues(t, int64(6), value)
@@ -106,7 +107,7 @@ func TestUnmarshalRowUint(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint var value uint
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint(2), value) assert.EqualValues(t, uint(2), value)
@@ -119,7 +120,7 @@ func TestUnmarshalRowUint8(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint8 var value uint8
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint8(3), value) assert.EqualValues(t, uint8(3), value)
@@ -132,7 +133,7 @@ func TestUnmarshalRowUint16(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint16 var value uint16
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint16(4), value) assert.EqualValues(t, uint16(4), value)
@@ -145,7 +146,7 @@ func TestUnmarshalRowUint32(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint32 var value uint32
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint32(5), value) assert.EqualValues(t, uint32(5), value)
@@ -158,7 +159,7 @@ func TestUnmarshalRowUint64(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint64 var value uint64
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint16(6), value) assert.EqualValues(t, uint16(6), value)
@@ -171,7 +172,7 @@ func TestUnmarshalRowFloat32(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value float32 var value float32
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, float32(7), value) assert.EqualValues(t, float32(7), value)
@@ -184,7 +185,7 @@ func TestUnmarshalRowFloat64(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value float64 var value float64
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, float64(8), value) assert.EqualValues(t, float64(8), value)
@@ -198,7 +199,7 @@ func TestUnmarshalRowString(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value string var value string
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true) return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -215,7 +216,7 @@ func TestUnmarshalRowStruct(t *testing.T) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true) return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone")) }, "select name, age from users where user=?", "anyone"))
assert.Equal(t, "liao", value.Name) assert.Equal(t, "liao", value.Name)
@@ -233,7 +234,7 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true) return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone")) }, "select name, age from users where user=?", "anyone"))
assert.Equal(t, "liao", value.Name) assert.Equal(t, "liao", value.Name)
@@ -251,7 +252,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao") rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.NotNil(t, query(db, func(rows *sql.Rows) error { assert.NotNil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true) return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone")) }, "select name, age from users where user=?", "anyone"))
}) })
@@ -264,7 +265,7 @@ func TestUnmarshalRowsBool(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []bool var value []bool
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -278,7 +279,7 @@ func TestUnmarshalRowsInt(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int var value []int
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -292,7 +293,7 @@ func TestUnmarshalRowsInt8(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int8 var value []int8
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -306,7 +307,7 @@ func TestUnmarshalRowsInt16(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int16 var value []int16
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -320,7 +321,7 @@ func TestUnmarshalRowsInt32(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int32 var value []int32
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -334,7 +335,7 @@ func TestUnmarshalRowsInt64(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int64 var value []int64
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -348,7 +349,7 @@ func TestUnmarshalRowsUint(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint var value []uint
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -362,7 +363,7 @@ func TestUnmarshalRowsUint8(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint8 var value []uint8
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -376,7 +377,7 @@ func TestUnmarshalRowsUint16(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint16 var value []uint16
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -390,7 +391,7 @@ func TestUnmarshalRowsUint32(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint32 var value []uint32
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -404,7 +405,7 @@ func TestUnmarshalRowsUint64(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint64 var value []uint64
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -418,7 +419,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []float32 var value []float32
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -432,7 +433,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []float64 var value []float64
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -446,7 +447,7 @@ func TestUnmarshalRowsString(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []string var value []string
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -462,7 +463,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*bool var value []*bool
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -478,7 +479,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int var value []*int
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -494,7 +495,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int8 var value []*int8
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -510,7 +511,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int16 var value []*int16
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -526,7 +527,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int32 var value []*int32
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -542,7 +543,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int64 var value []*int64
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -558,7 +559,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint var value []*uint
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -574,7 +575,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint8 var value []*uint8
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -590,7 +591,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint16 var value []*uint16
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -606,7 +607,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint32 var value []*uint32
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -622,7 +623,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint64 var value []*uint64
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -638,7 +639,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*float32 var value []*float32
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -654,7 +655,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*float64 var value []*float64
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -670,7 +671,7 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*string var value []*string
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone")) }, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value) assert.EqualValues(t, expect, value)
@@ -699,7 +700,7 @@ func TestUnmarshalRowsStruct(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone")) }, "select name, age from users where user=?", "anyone"))
@@ -739,7 +740,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow( rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
"first", "firstnullstring").AddRow("second", nil) "first", "firstnullstring").AddRow("second", nil)
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone")) }, "select name, age from users where user=?", "anyone"))
@@ -773,7 +774,7 @@ func TestUnmarshalRowsStructWithTags(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone")) }, "select name, age from users where user=?", "anyone"))
@@ -814,7 +815,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4") rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select name, age, value from users where user=?", "anyone")) }, "select name, age, value from users where user=?", "anyone"))
@@ -856,7 +857,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4") rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select name, age, value from users where user=?", "anyone")) }, "select name, age, value from users where user=?", "anyone"))
@@ -890,7 +891,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone")) }, "select name, age from users where user=?", "anyone"))
@@ -923,7 +924,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone")) }, "select name, age from users where user=?", "anyone"))
@@ -956,7 +957,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone")) }, "select name, age from users where user=?", "anyone"))
@@ -976,7 +977,7 @@ func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
User string `db:"user"` User string `db:"user"`
Age int `db:"age"` Age int `db:"age"`
} }
assert.Nil(t, query(db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&r, rows, false) return unmarshalRow(&r, rows, false)
}, "select age from users where user=?", "anyone")) }, "select age from users where user=?", "anyone"))
assert.Empty(t, r.User) assert.Empty(t, r.User)
@@ -1027,7 +1028,7 @@ func TestUnmarshalRowError(t *testing.T) {
User string `db:"user"` User string `db:"user"`
Age int `db:"age"` Age int `db:"age"`
} }
test.validate(query(db, func(rows *sql.Rows) error { test.validate(query(context.Background(), db, func(rows *sql.Rows) error {
scanner := mockedScanner{ scanner := mockedScanner{
colErr: test.colErr, colErr: test.colErr,
scanErr: test.scanErr, scanErr: test.scanErr,

View File

@@ -1,12 +1,19 @@
package sqlx package sqlx
import ( import (
"context"
"database/sql" "database/sql"
"github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/trace"
"go.opentelemetry.io/otel"
tracesdk "go.opentelemetry.io/otel/trace"
) )
// spanName is used to identify the span name for the SQL execution.
const spanName = "sql"
// ErrNotFound is an alias of sql.ErrNoRows // ErrNotFound is an alias of sql.ErrNoRows
var ErrNotFound = sql.ErrNoRows var ErrNotFound = sql.ErrNoRows
@@ -14,11 +21,17 @@ type (
// Session stands for raw connections or transaction sessions // Session stands for raw connections or transaction sessions
Session interface { Session interface {
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (StmtSession, error) Prepare(query string) (StmtSession, error)
PrepareCtx(ctx context.Context, query string) (StmtSession, error)
QueryRow(v interface{}, query string, args ...interface{}) error QueryRow(v interface{}, query string, args ...interface{}) error
QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
QueryRowPartial(v interface{}, query string, args ...interface{}) error QueryRowPartial(v interface{}, query string, args ...interface{}) error
QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
QueryRows(v interface{}, query string, args ...interface{}) error QueryRows(v interface{}, query string, args ...interface{}) error
QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
QueryRowsPartial(v interface{}, query string, args ...interface{}) error QueryRowsPartial(v interface{}, query string, args ...interface{}) error
QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
} }
// SqlConn only stands for raw connections, so Transact method can be called. // SqlConn only stands for raw connections, so Transact method can be called.
@@ -27,7 +40,8 @@ type (
// RawDB is for other ORM to operate with, use it with caution. // RawDB is for other ORM to operate with, use it with caution.
// Notice: don't close it. // Notice: don't close it.
RawDB() (*sql.DB, error) RawDB() (*sql.DB, error)
Transact(func(session Session) error) error Transact(fn func(Session) error) error
TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
} }
// SqlOption defines the method to customize a sql connection. // SqlOption defines the method to customize a sql connection.
@@ -37,10 +51,15 @@ type (
StmtSession interface { StmtSession interface {
Close() error Close() error
Exec(args ...interface{}) (sql.Result, error) Exec(args ...interface{}) (sql.Result, error)
ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error)
QueryRow(v interface{}, args ...interface{}) error QueryRow(v interface{}, args ...interface{}) error
QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error
QueryRowPartial(v interface{}, args ...interface{}) error QueryRowPartial(v interface{}, args ...interface{}) error
QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
QueryRows(v interface{}, args ...interface{}) error QueryRows(v interface{}, args ...interface{}) error
QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error
QueryRowsPartial(v interface{}, args ...interface{}) error QueryRowsPartial(v interface{}, args ...interface{}) error
QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
} }
// thread-safe // thread-safe
@@ -58,7 +77,9 @@ type (
sessionConn interface { sessionConn interface {
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error) Query(query string, args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
} }
statement struct { statement struct {
@@ -68,7 +89,9 @@ type (
stmtConn interface { stmtConn interface {
Exec(args ...interface{}) (sql.Result, error) Exec(args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error) Query(args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
} }
) )
@@ -112,6 +135,14 @@ func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
} }
func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) { func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
return db.ExecCtx(context.Background(), q, args...)
}
func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...interface{}) (
result sql.Result, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = db.brk.DoWithAcceptable(func() error { err = db.brk.DoWithAcceptable(func() error {
var conn *sql.DB var conn *sql.DB
conn, err = db.connProv() conn, err = db.connProv()
@@ -120,7 +151,7 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result,
return err return err
} }
result, err = exec(conn, q, args...) result, err = exec(ctx, conn, q, args...)
return err return err
}, db.acceptable) }, db.acceptable)
@@ -128,6 +159,13 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result,
} }
func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) { func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
return db.PrepareCtx(context.Background(), query)
}
func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
ctx, span := startSpan(ctx)
defer span.End()
err = db.brk.DoWithAcceptable(func() error { err = db.brk.DoWithAcceptable(func() error {
var conn *sql.DB var conn *sql.DB
conn, err = db.connProv() conn, err = db.connProv()
@@ -136,7 +174,7 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
return err return err
} }
st, err := conn.Prepare(query) st, err := conn.PrepareContext(ctx, query)
if err != nil { if err != nil {
return err return err
} }
@@ -152,25 +190,57 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
} }
func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error { func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error { return db.QueryRowCtx(context.Background(), v, q, args...)
}
func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v interface{}, q string,
args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return db.queryRows(ctx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true) return unmarshalRow(v, rows, true)
}, q, args...) }, q, args...)
} }
func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error { func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error { return db.QueryRowPartialCtx(context.Background(), v, q, args...)
}
func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v interface{},
q string, args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return db.queryRows(ctx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false) return unmarshalRow(v, rows, false)
}, q, args...) }, q, args...)
} }
func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error { func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error { return db.QueryRowsCtx(context.Background(), v, q, args...)
}
func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v interface{}, q string,
args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return db.queryRows(ctx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true) return unmarshalRows(v, rows, true)
}, q, args...) }, q, args...)
} }
func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error { return db.QueryRowsPartialCtx(context.Background(), v, q, args...)
}
func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{},
q string, args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return db.queryRows(ctx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false) return unmarshalRows(v, rows, false)
}, q, args...) }, q, args...)
} }
@@ -180,13 +250,22 @@ func (db *commonSqlConn) RawDB() (*sql.DB, error) {
} }
func (db *commonSqlConn) Transact(fn func(Session) error) error { func (db *commonSqlConn) Transact(fn func(Session) error) error {
return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error {
return fn(session)
})
}
func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
ctx, span := startSpan(ctx)
defer span.End()
return db.brk.DoWithAcceptable(func() error { return db.brk.DoWithAcceptable(func() error {
return transact(db, db.beginTx, fn) return transact(ctx, db, db.beginTx, fn)
}, db.acceptable) }, db.acceptable)
} }
func (db *commonSqlConn) acceptable(err error) bool { func (db *commonSqlConn) acceptable(err error) bool {
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
if db.accept == nil { if db.accept == nil {
return ok return ok
} }
@@ -194,7 +273,11 @@ func (db *commonSqlConn) acceptable(err error) bool {
return ok || db.accept(err) return ok || db.accept(err)
} }
func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error { func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
q string, args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
var qerr error var qerr error
return db.brk.DoWithAcceptable(func() error { return db.brk.DoWithAcceptable(func() error {
conn, err := db.connProv() conn, err := db.connProv()
@@ -203,7 +286,7 @@ func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args
return err return err
} }
return query(conn, func(rows *sql.Rows) error { return query(ctx, conn, func(rows *sql.Rows) error {
qerr = scanner(rows) qerr = scanner(rows)
return qerr return qerr
}, q, args...) }, q, args...)
@@ -217,29 +300,69 @@ func (s statement) Close() error {
} }
func (s statement) Exec(args ...interface{}) (sql.Result, error) { func (s statement) Exec(args ...interface{}) (sql.Result, error) {
return execStmt(s.stmt, s.query, args...) return s.ExecCtx(context.Background(), args...)
}
func (s statement) ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error) {
ctx, span := startSpan(ctx)
defer span.End()
return execStmt(ctx, s.stmt, s.query, args...)
} }
func (s statement) QueryRow(v interface{}, args ...interface{}) error { func (s statement) QueryRow(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error { return s.QueryRowCtx(context.Background(), v, args...)
}
func (s statement) QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true) return unmarshalRow(v, rows, true)
}, s.query, args...) }, s.query, args...)
} }
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error { func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error { return s.QueryRowPartialCtx(context.Background(), v, args...)
}
func (s statement) QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false) return unmarshalRow(v, rows, false)
}, s.query, args...) }, s.query, args...)
} }
func (s statement) QueryRows(v interface{}, args ...interface{}) error { func (s statement) QueryRows(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error { return s.QueryRowsCtx(context.Background(), v, args...)
}
func (s statement) QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true) return unmarshalRows(v, rows, true)
}, s.query, args...) }, s.query, args...)
} }
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error { func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error { return s.QueryRowsPartialCtx(context.Background(), v, args...)
}
func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false) return unmarshalRows(v, rows, false)
}, s.query, args...) }, s.query, args...)
} }
func startSpan(ctx context.Context) (context.Context, tracesdk.Span) {
tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
return tracer.Start(ctx, spanName)
}

View File

@@ -1,6 +1,7 @@
package sqlx package sqlx
import ( import (
"context"
"database/sql" "database/sql"
"time" "time"
@@ -18,64 +19,65 @@ func SetSlowThreshold(threshold time.Duration) {
slowThreshold.Set(threshold) slowThreshold.Set(threshold)
} }
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) { func exec(ctx context.Context, conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...) stmt, err := format(q, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
startTime := timex.Now() startTime := timex.Now()
result, err := conn.Exec(q, args...) result, err := conn.ExecContext(ctx, q, args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
if duration > slowThreshold.Load() { if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt) logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
} else { } else {
logx.WithDuration(duration).Infof("sql exec: %s", stmt) logx.WithContext(ctx).WithDuration(duration).Infof("sql exec: %s", stmt)
} }
if err != nil { if err != nil {
logSqlError(stmt, err) logSqlError(ctx, stmt, err)
} }
return result, err return result, err
} }
func execStmt(conn stmtConn, q string, args ...interface{}) (sql.Result, error) { func execStmt(ctx context.Context, conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...) stmt, err := format(q, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
startTime := timex.Now() startTime := timex.Now()
result, err := conn.Exec(args...) result, err := conn.ExecContext(ctx, args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
if duration > slowThreshold.Load() { if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt) logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
} else { } else {
logx.WithDuration(duration).Infof("sql execStmt: %s", stmt) logx.WithContext(ctx).WithDuration(duration).Infof("sql execStmt: %s", stmt)
} }
if err != nil { if err != nil {
logSqlError(stmt, err) logSqlError(ctx, stmt, err)
} }
return result, err return result, err
} }
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error { func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error,
q string, args ...interface{}) error {
stmt, err := format(q, args...) stmt, err := format(q, args...)
if err != nil { if err != nil {
return err return err
} }
startTime := timex.Now() startTime := timex.Now()
rows, err := conn.Query(q, args...) rows, err := conn.QueryContext(ctx, q, args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
if duration > slowThreshold.Load() { if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt) logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
} else { } else {
logx.WithDuration(duration).Infof("sql query: %s", stmt) logx.WithContext(ctx).WithDuration(duration).Infof("sql query: %s", stmt)
} }
if err != nil { if err != nil {
logSqlError(stmt, err) logSqlError(ctx, stmt, err)
return err return err
} }
defer rows.Close() defer rows.Close()
@@ -83,22 +85,23 @@ func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...in
return scanner(rows) return scanner(rows)
} }
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error { func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error,
q string, args ...interface{}) error {
stmt, err := format(q, args...) stmt, err := format(q, args...)
if err != nil { if err != nil {
return err return err
} }
startTime := timex.Now() startTime := timex.Now()
rows, err := conn.Query(args...) rows, err := conn.QueryContext(ctx, args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
if duration > slowThreshold.Load() { if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt) logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
} else { } else {
logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt) logx.WithContext(ctx).WithDuration(duration).Infof("sql queryStmt: %s", stmt)
} }
if err != nil { if err != nil {
logSqlError(stmt, err) logSqlError(ctx, stmt, err)
return err return err
} }
defer rows.Close() defer rows.Close()

View File

@@ -1,6 +1,7 @@
package sqlx package sqlx
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"testing" "testing"
@@ -57,7 +58,7 @@ func TestStmt_exec(t *testing.T) {
test := test test := test
fns := []func(args ...interface{}) (sql.Result, error){ fns := []func(args ...interface{}) (sql.Result, error){
func(args ...interface{}) (sql.Result, error) { func(args ...interface{}) (sql.Result, error) {
return exec(&mockedSessionConn{ return exec(context.Background(), &mockedSessionConn{
lastInsertId: test.lastInsertId, lastInsertId: test.lastInsertId,
rowsAffected: test.rowsAffected, rowsAffected: test.rowsAffected,
err: test.err, err: test.err,
@@ -65,7 +66,7 @@ func TestStmt_exec(t *testing.T) {
}, test.query, args...) }, test.query, args...)
}, },
func(args ...interface{}) (sql.Result, error) { func(args ...interface{}) (sql.Result, error) {
return execStmt(&mockedStmtConn{ return execStmt(context.Background(), &mockedStmtConn{
lastInsertId: test.lastInsertId, lastInsertId: test.lastInsertId,
rowsAffected: test.rowsAffected, rowsAffected: test.rowsAffected,
err: test.err, err: test.err,
@@ -137,7 +138,7 @@ func TestStmt_query(t *testing.T) {
test := test test := test
fns := []func(args ...interface{}) error{ fns := []func(args ...interface{}) error{
func(args ...interface{}) error { func(args ...interface{}) error {
return query(&mockedSessionConn{ return query(context.Background(), &mockedSessionConn{
err: test.err, err: test.err,
delay: test.delay, delay: test.delay,
}, func(rows *sql.Rows) error { }, func(rows *sql.Rows) error {
@@ -145,7 +146,7 @@ func TestStmt_query(t *testing.T) {
}, test.query, args...) }, test.query, args...)
}, },
func(args ...interface{}) error { func(args ...interface{}) error {
return queryStmt(&mockedStmtConn{ return queryStmt(context.Background(), &mockedStmtConn{
err: test.err, err: test.err,
delay: test.delay, delay: test.delay,
}, func(rows *sql.Rows) error { }, func(rows *sql.Rows) error {
@@ -185,6 +186,10 @@ type mockedSessionConn struct {
} }
func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) { func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) {
return m.ExecContext(context.Background(), query, args...)
}
func (m *mockedSessionConn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if m.delay { if m.delay {
time.Sleep(defaultSlowThreshold + time.Millisecond) time.Sleep(defaultSlowThreshold + time.Millisecond)
} }
@@ -195,6 +200,10 @@ func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result,
} }
func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) { func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) {
return m.QueryContext(context.Background(), query, args...)
}
func (m *mockedSessionConn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
if m.delay { if m.delay {
time.Sleep(defaultSlowThreshold + time.Millisecond) time.Sleep(defaultSlowThreshold + time.Millisecond)
} }
@@ -214,6 +223,10 @@ type mockedStmtConn struct {
} }
func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) { func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
return m.ExecContext(context.Background(), args...)
}
func (m *mockedStmtConn) ExecContext(_ context.Context, _ ...interface{}) (sql.Result, error) {
if m.delay { if m.delay {
time.Sleep(defaultSlowThreshold + time.Millisecond) time.Sleep(defaultSlowThreshold + time.Millisecond)
} }
@@ -224,6 +237,10 @@ func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
} }
func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) { func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) {
return m.QueryContext(context.Background(), args...)
}
func (m *mockedStmtConn) QueryContext(_ context.Context, _ ...interface{}) (*sql.Rows, error) {
if m.delay { if m.delay {
time.Sleep(defaultSlowThreshold + time.Millisecond) time.Sleep(defaultSlowThreshold + time.Millisecond)
} }

View File

@@ -1,6 +1,7 @@
package sqlx package sqlx
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
) )
@@ -26,11 +27,25 @@ func NewSessionFromTx(tx *sql.Tx) Session {
} }
func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) { func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
return exec(t.Tx, q, args...) return t.ExecCtx(context.Background(), q, args...)
}
func (t txSession) ExecCtx(ctx context.Context, q string, args ...interface{}) (sql.Result, error) {
ctx, span := startSpan(ctx)
defer span.End()
return exec(ctx, t.Tx, q, args...)
} }
func (t txSession) Prepare(q string) (StmtSession, error) { func (t txSession) Prepare(q string) (StmtSession, error) {
stmt, err := t.Tx.Prepare(q) return t.PrepareCtx(context.Background(), q)
}
func (t txSession) PrepareCtx(ctx context.Context, q string) (StmtSession, error) {
ctx, span := startSpan(ctx)
defer span.End()
stmt, err := t.Tx.PrepareContext(ctx, q)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -42,25 +57,55 @@ func (t txSession) Prepare(q string) (StmtSession, error) {
} }
func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error { func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error { return t.QueryRowCtx(context.Background(), v, q, args...)
}
func (t txSession) QueryRowCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return query(ctx, t.Tx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true) return unmarshalRow(v, rows, true)
}, q, args...) }, q, args...)
} }
func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error { func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error { return t.QueryRowPartialCtx(context.Background(), v, q, args...)
}
func (t txSession) QueryRowPartialCtx(ctx context.Context, v interface{}, q string,
args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return query(ctx, t.Tx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false) return unmarshalRow(v, rows, false)
}, q, args...) }, q, args...)
} }
func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error { func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error { return t.QueryRowsCtx(context.Background(), v, q, args...)
}
func (t txSession) QueryRowsCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return query(ctx, t.Tx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true) return unmarshalRows(v, rows, true)
}, q, args...) }, q, args...)
} }
func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error { return t.QueryRowsPartialCtx(context.Background(), v, q, args...)
}
func (t txSession) QueryRowsPartialCtx(ctx context.Context, v interface{}, q string,
args ...interface{}) error {
ctx, span := startSpan(ctx)
defer span.End()
return query(ctx, t.Tx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false) return unmarshalRows(v, rows, false)
}, q, args...) }, q, args...)
} }
@@ -76,17 +121,19 @@ func begin(db *sql.DB) (trans, error) {
}, nil }, nil
} }
func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) { func transact(ctx context.Context, db *commonSqlConn, b beginnable,
fn func(context.Context, Session) error) (err error) {
conn, err := db.connProv() conn, err := db.connProv()
if err != nil { if err != nil {
db.onError(err) db.onError(err)
return err return err
} }
return transactOnConn(conn, b, fn) return transactOnConn(ctx, conn, b, fn)
} }
func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) { func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable,
fn func(context.Context, Session) error) (err error) {
var tx trans var tx trans
tx, err = b(conn) tx, err = b(conn)
if err != nil { if err != nil {
@@ -96,18 +143,18 @@ func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err err
defer func() { defer func() {
if p := recover(); p != nil { if p := recover(); p != nil {
if e := tx.Rollback(); e != nil { if e := tx.Rollback(); e != nil {
err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e) err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e)
} else { } else {
err = fmt.Errorf("recoveer from %#v", p) err = fmt.Errorf("recoveer from %#v", p)
} }
} else if err != nil { } else if err != nil {
if e := tx.Rollback(); e != nil { if e := tx.Rollback(); e != nil {
err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e) err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e)
} }
} else { } else {
err = tx.Commit() err = tx.Commit()
} }
}() }()
return fn(tx) return fn(ctx, tx)
} }

View File

@@ -1,6 +1,7 @@
package sqlx package sqlx
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"testing" "testing"
@@ -26,26 +27,50 @@ func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) {
return nil, nil return nil, nil
} }
func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return nil, nil
}
func (mt *mockTx) Prepare(query string) (StmtSession, error) { func (mt *mockTx) Prepare(query string) (StmtSession, error) {
return nil, nil return nil, nil
} }
func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
return nil, nil
}
func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error { func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error {
return nil return nil
} }
func (mt *mockTx) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error { func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return nil return nil
} }
func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error { func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error {
return nil return nil
} }
func (mt *mockTx) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return nil return nil
} }
func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
return nil
}
func (mt *mockTx) Rollback() error { func (mt *mockTx) Rollback() error {
mt.status |= mockRollback mt.status |= mockRollback
return nil return nil
@@ -59,18 +84,20 @@ func beginMock(mock *mockTx) beginnable {
func TestTransactCommit(t *testing.T) { func TestTransactCommit(t *testing.T) {
mock := &mockTx{} mock := &mockTx{}
err := transactOnConn(nil, beginMock(mock), func(Session) error { err := transactOnConn(context.Background(), nil, beginMock(mock),
return nil func(context.Context, Session) error {
}) return nil
})
assert.Equal(t, mockCommit, mock.status) assert.Equal(t, mockCommit, mock.status)
assert.Nil(t, err) assert.Nil(t, err)
} }
func TestTransactRollback(t *testing.T) { func TestTransactRollback(t *testing.T) {
mock := &mockTx{} mock := &mockTx{}
err := transactOnConn(nil, beginMock(mock), func(Session) error { err := transactOnConn(context.Background(), nil, beginMock(mock),
return errors.New("rollback") func(context.Context, Session) error {
}) return errors.New("rollback")
})
assert.Equal(t, mockRollback, mock.status) assert.Equal(t, mockRollback, mock.status)
assert.NotNil(t, err) assert.NotNil(t, err)
} }

View File

@@ -1,6 +1,8 @@
package sqlx package sqlx
import ( import (
"context"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@@ -9,6 +11,8 @@ import (
"github.com/zeromicro/go-zero/core/mapping" "github.com/zeromicro/go-zero/core/mapping"
) )
var errUnbalancedEscape = errors.New("no char after escape char")
func desensitize(datasource string) string { func desensitize(datasource string) string {
// remove account // remove account
pos := strings.LastIndex(datasource, "@") pos := strings.LastIndex(datasource, "@")
@@ -66,7 +70,7 @@ func format(query string, args ...interface{}) (string, error) {
writeValue(&b, args[argIndex]) writeValue(&b, args[argIndex])
argIndex++ argIndex++
case '$': case ':', '$':
var j int var j int
for j = i + 1; j < bytes; j++ { for j = i + 1; j < bytes; j++ {
char := query[j] char := query[j]
@@ -74,16 +78,18 @@ func format(query string, args ...interface{}) (string, error) {
break break
} }
} }
if j > i+1 { if j > i+1 {
index, err := strconv.Atoi(query[i+1 : j]) index, err := strconv.Atoi(query[i+1 : j])
if err != nil { if err != nil {
return "", err return "", err
} }
// index starts from 1 for pg // index starts from 1 for pg or oracle
if index > argIndex { if index > argIndex {
argIndex = index argIndex = index
} }
index-- index--
if index < 0 || numArgs <= index { if index < 0 || numArgs <= index {
return "", fmt.Errorf("error: wrong index %d in sql", index) return "", fmt.Errorf("error: wrong index %d in sql", index)
@@ -92,6 +98,25 @@ func format(query string, args ...interface{}) (string, error) {
writeValue(&b, args[index]) writeValue(&b, args[index])
i = j - 1 i = j - 1
} }
case '\'', '"', '`':
b.WriteByte(ch)
for j := i + 1; j < bytes; j++ {
cur := query[j]
b.WriteByte(cur)
if cur == '\\' {
j++
if j >= bytes {
return "", errUnbalancedEscape
}
b.WriteByte(query[j])
} else if cur == ch {
i = j
break
}
}
default: default:
b.WriteByte(ch) b.WriteByte(ch)
} }
@@ -109,9 +134,9 @@ func logInstanceError(datasource string, err error) {
logx.Errorf("Error on getting sql instance of %s: %v", datasource, err) logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
} }
func logSqlError(stmt string, err error) { func logSqlError(ctx context.Context, stmt string, err error) {
if err != nil && err != ErrNotFound { if err != nil && err != ErrNotFound {
logx.Errorf("stmt: %s, error: %s", stmt, err.Error()) logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
} }
} }

View File

@@ -73,6 +73,54 @@ func TestFormat(t *testing.T) {
args: []interface{}{"133", false}, args: []interface{}{"133", false},
hasErr: true, hasErr: true,
}, },
{
name: "oracle normal",
query: "select name, age from users where bool=:1 and phone=:2",
args: []interface{}{true, "133"},
expect: "select name, age from users where bool=1 and phone='133'",
},
{
name: "oracle normal reverse",
query: "select name, age from users where bool=:2 and phone=:1",
args: []interface{}{"133", false},
expect: "select name, age from users where bool=0 and phone='133'",
},
{
name: "oracle error not number",
query: "select name, age from users where bool=:a and phone=:1",
args: []interface{}{"133", false},
hasErr: true,
},
{
name: "oracle error more args",
query: "select name, age from users where bool=:2 and phone=:1 and nickname=:3",
args: []interface{}{"133", false},
hasErr: true,
},
{
name: "select with date",
query: "select * from user where date='2006-01-02 15:04:05' and name=:1",
args: []interface{}{"foo"},
expect: "select * from user where date='2006-01-02 15:04:05' and name='foo'",
},
{
name: "select with date and escape",
query: `select * from user where date=' 2006-01-02 15:04:05 \'' and name=:1`,
args: []interface{}{"foo"},
expect: `select * from user where date=' 2006-01-02 15:04:05 \'' and name='foo'`,
},
{
name: "select with date and bad arg",
query: `select * from user where date='2006-01-02 15:04:05 \'' and name=:a`,
args: []interface{}{"foo"},
hasErr: true,
},
{
name: "select with date and escape error",
query: `select * from user where date='2006-01-02 15:04:05 \`,
args: []interface{}{"foo"},
hasErr: true,
},
} }
for _, test := range tests { for _, test := range tests {
@@ -84,6 +132,7 @@ func TestFormat(t *testing.T) {
if test.hasErr { if test.hasErr {
assert.NotNil(t, err) assert.NotNil(t, err)
} else { } else {
assert.Nil(t, err)
assert.Equal(t, test.expect, actual) assert.Equal(t, test.expect, actual)
} }
}) })

View File

@@ -2,6 +2,8 @@ package stringx
type node struct { type node struct {
children map[rune]*node children map[rune]*node
fail *node
depth int
end bool end bool
} }
@@ -12,17 +14,19 @@ func (n *node) add(word string) {
} }
nd := n nd := n
for _, char := range chars { var depth int
for i, char := range chars {
if nd.children == nil { if nd.children == nil {
child := new(node) child := new(node)
nd.children = map[rune]*node{ child.depth = i + 1
char: child, nd.children = map[rune]*node{char: child}
}
nd = child nd = child
} else if child, ok := nd.children[char]; ok { } else if child, ok := nd.children[char]; ok {
nd = child nd = child
depth++
} else { } else {
child := new(node) child := new(node)
child.depth = i + 1
nd.children[char] = child nd.children[char] = child
nd = child nd = child
} }
@@ -30,3 +34,67 @@ func (n *node) add(word string) {
nd.end = true nd.end = true
} }
func (n *node) build() {
var nodes []*node
for _, child := range n.children {
child.fail = n
nodes = append(nodes, child)
}
for len(nodes) > 0 {
nd := nodes[0]
nodes = nodes[1:]
for key, child := range nd.children {
nodes = append(nodes, child)
cur := nd
for cur != nil {
if cur.fail == nil {
child.fail = n
break
}
if fail, ok := cur.fail.children[key]; ok {
child.fail = fail
break
}
cur = cur.fail
}
}
}
}
func (n *node) find(chars []rune) []scope {
var scopes []scope
size := len(chars)
cur := n
for i := 0; i < size; i++ {
child, ok := cur.children[chars[i]]
if ok {
cur = child
} else {
for cur != n {
cur = cur.fail
if child, ok = cur.children[chars[i]]; ok {
cur = child
break
}
}
if child == nil {
continue
}
}
for child != n {
if child.end {
scopes = append(scopes, scope{
start: i + 1 - child.depth,
stop: i + 1,
})
}
child = child.fail
}
}
return scopes
}

View File

@@ -0,0 +1,87 @@
//go:build go1.18
// +build go1.18
package stringx
import (
"fmt"
"math/rand"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func FuzzNodeFind(f *testing.F) {
rand.Seed(time.Now().UnixNano())
f.Add(10)
f.Fuzz(func(t *testing.T, keys int) {
str := Randn(rand.Intn(100) + 50)
keywords := make(map[string]struct{})
for i := 0; i < keys; i++ {
keyword := Randn(rand.Intn(10) + 5)
if !strings.Contains(str, keyword) {
keywords[keyword] = struct{}{}
}
}
size := len(str)
var scopes []scope
var n node
for i := 0; i < size%20; i++ {
start := rand.Intn(size)
stop := start + rand.Intn(20) + 1
if stop > size {
stop = size
}
if start == stop {
continue
}
keyword := str[start:stop]
if _, ok := keywords[keyword]; ok {
continue
}
keywords[keyword] = struct{}{}
var pos int
for pos <= len(str)-len(keyword) {
val := str[pos:]
p := strings.Index(val, keyword)
if p < 0 {
break
}
scopes = append(scopes, scope{
start: pos + p,
stop: pos + p + len(keyword),
})
pos += p + 1
}
}
for keyword := range keywords {
n.add(keyword)
}
n.build()
var buf strings.Builder
buf.WriteString("keywords:\n")
for key := range keywords {
fmt.Fprintf(&buf, "\t%q,\n", key)
}
buf.WriteString("scopes:\n")
for _, scp := range scopes {
fmt.Fprintf(&buf, "\t{%d, %d},\n", scp.start, scp.stop)
}
fmt.Fprintf(&buf, "text:\n\t%s\n", str)
defer func() {
if r := recover(); r != nil {
t.Errorf(buf.String())
}
}()
assert.ElementsMatchf(t, scopes, n.find([]rune(str)), buf.String())
})
}

195
core/stringx/node_test.go Normal file
View File

@@ -0,0 +1,195 @@
package stringx
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFuzzNodeCase1(t *testing.T) {
keywords := []string{
"cs8Zh",
"G1OihlVuBz",
"K6azS2FBHjI",
"DQKvghI4",
"l7bA86Sze",
"tjBLZhCao",
"nEsXmVzP",
"cbRh8UE1nO3s",
"Wta3R2WcbGP",
"jpOIcA",
"TtkRr4k9hI",
"OKbSo0clAYTtk",
"uJs1WToEanlKV",
"05Y02iFD2",
"x2uJs1WToEanlK",
"ieaSWe",
"Kg",
"FD2bCKFazH",
}
scopes := []scope{
{62, 72},
{52, 65},
{21, 34},
{1, 10},
{19, 33},
{36, 42},
{42, 44},
{7, 17},
}
n := new(node)
for _, key := range keywords {
n.add(key)
}
n.build()
assert.ElementsMatch(t, scopes, n.find([]rune("Z05Y02iFD2bCKFazHtrx2uJs1WToEanlKVWKieaSWeKgmnUXV0ZjOKbSo0clAYTtkRr4k9hI")))
}
func TestFuzzNodeCase2(t *testing.T) {
keywords := []string{
"IP1NPsJKIvt",
"Iw7hQARwSTw",
"SmZIcA",
"OyxHPYkoQzFO",
"3suCnuSAS5d",
"HUMpbi",
"HPdvbGGpY",
"49qjMtR8bEt",
"a0zrrGKVTJ2",
"WbOBcszeo1FfZ",
"8tHUi5PJI",
"Oa2Di",
"6ZWa5nr1tU",
"o0LJRfmeXB9bF9",
"veF0ehKxH",
"Qp73r",
"B6Rmds4ELY8",
"uNpGybQZG",
"Ogm3JqicRZlA4n",
"FL6LVErKomc84H",
"qv2Pi0xJj3cR1",
"bPWLBg4",
"hYN8Q4M1sw",
"ExkTgNklmlIx",
"eVgHHDOxOUEj",
"5WPEVv0tR",
"CPjnOAqUZgV",
"oR3Ogtz",
"jwk1Zbg",
"DYqguyk8h",
"rieroDmpvYFK",
"MQ9hZnMjDqrNQe",
"EhM4KqkCBd",
"m9xalj6q",
"d5CTL5mzK",
"XJOoTvFtI8U",
"iFAwspJ",
"iGv8ErnRZIuSWX",
"i8C1BqsYX",
"vXN1KOaOgU",
"GHJFB",
"Y6OlAqbZxYG8",
"dzd4QscSih4u",
"SsLYMkKvB9APx",
"gi0huB3",
"CMICHDCSvSrgiACXVkN",
"MwOvyHbaxdaqpZpU",
"wOvyHbaxdaqpZpUbI",
"2TT5WEy",
"eoCq0T2MC",
"ZpUbI7",
"oCq0T2MCp",
"CpLFgLg0g",
"FgLg0gh",
"w5awC5HeoCq",
"1c",
}
scopes := []scope{
{0, 19},
{57, 73},
{58, 75},
{47, 54},
{29, 38},
{70, 76},
{30, 39},
{37, 46},
{40, 47},
{22, 33},
{92, 94},
}
n := new(node)
for _, key := range keywords {
n.add(key)
}
n.build()
assert.ElementsMatch(t, scopes, n.find([]rune("CMICHDCSvSrgiACXVkNF9lw5awC5HeoCq0T2MCpLFgLg0gh2TT5WEyINrMwOvyHbaxdaqpZpUbI7SpIY5yVWf33MuX7K1c")))
}
func TestFuzzNodeCase3(t *testing.T) {
keywords := []string{
"QAraACKOftI4",
"unRmd2EO0",
"s25OtuoU",
"aGlmn7KnbE4HCX",
"kuK6Uh",
"ckuK6Uh",
"uK6Uh",
"Iy",
"h",
"PMSSUNvyi",
"ahz0i",
"Lhs4XZ1e",
"shPp1Va7aQNVme",
"yIUckuK6Uh",
"pKjIyI",
"jIyIUckuK6Uh",
"UckuK6Uh",
"Uh",
"JPAULjQgHJ",
"Wp",
"sbkZxXurrI",
"pKjIyIUckuK6Uh",
}
scopes := []scope{
{9, 15},
{8, 15},
{5, 15},
{1, 7},
{10, 15},
{3, 15},
{0, 2},
{1, 15},
{7, 15},
{13, 15},
{4, 6},
{14, 15},
}
n := new(node)
for _, key := range keywords {
n.add(key)
}
n.build()
assert.ElementsMatch(t, scopes, n.find([]rune("WpKjIyIUckuK6Uh")))
}
func BenchmarkNodeFind(b *testing.B) {
b.ReportAllocs()
keywords := []string{
"A",
"AV",
"AV演员",
"无名氏",
"AV演员色情",
"日本AV女优",
}
trie := new(node)
for _, keyword := range keywords {
trie.add(keyword)
}
trie.build()
for i := 0; i < b.N; i++ {
trie.find([]rune("日本AV演员兼电视、电影演员。无名氏AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演"))
}
}

View File

@@ -9,7 +9,7 @@ type (
} }
replacer struct { replacer struct {
node *node
mapping map[string]string mapping map[string]string
} }
) )
@@ -17,58 +17,81 @@ type (
// NewReplacer returns a Replacer. // NewReplacer returns a Replacer.
func NewReplacer(mapping map[string]string) Replacer { func NewReplacer(mapping map[string]string) Replacer {
rep := &replacer{ rep := &replacer{
node: new(node),
mapping: mapping, mapping: mapping,
} }
for k := range mapping { for k := range mapping {
rep.add(k) rep.add(k)
} }
rep.build()
return rep return rep
} }
// Replace replaces text with given substitutes.
func (r *replacer) Replace(text string) string { func (r *replacer) Replace(text string) string {
var builder strings.Builder var builder strings.Builder
var start int
chars := []rune(text) chars := []rune(text)
size := len(chars) size := len(chars)
start := -1
for i := 0; i < size; i++ { for start < size {
child, ok := r.children[chars[i]] cur := r.node
if !ok {
builder.WriteRune(chars[i]) if start > 0 {
continue builder.WriteString(string(chars[:start]))
} }
if start < 0 { for i := start; i < size; i++ {
start = i child, ok := cur.children[chars[i]]
} if ok {
end := -1 cur = child
if child.end { } else if cur == r.node {
end = i + 1 builder.WriteRune(chars[i])
} // cur already points to root, set start only
start = i + 1
continue
} else {
curDepth := cur.depth
cur = cur.fail
child, ok = cur.children[chars[i]]
if !ok {
// write this path
builder.WriteString(string(chars[i-curDepth : i+1]))
// go to root
cur = r.node
start = i + 1
continue
}
j := i + 1 failDepth := cur.depth
for ; j < size; j++ { // write path before jump
grandchild, ok := child.children[chars[j]] builder.WriteString(string(chars[start : start+curDepth-failDepth]))
if !ok { start += curDepth - failDepth
cur = child
}
if cur.end {
val := string(chars[i+1-cur.depth : i+1])
builder.WriteString(r.mapping[val])
builder.WriteString(string(chars[i+1:]))
// only matching this path, all previous paths are done
if start >= i+1-cur.depth && i+1 >= size {
return builder.String()
}
chars = []rune(builder.String())
size = len(chars)
builder.Reset()
break break
} }
child = grandchild
if child.end {
end = j + 1
i = j
}
} }
if end > 0 { if !cur.end {
i = j - 1 builder.WriteString(string(chars[start:]))
builder.WriteString(r.mapping[string(chars[start:end])]) return builder.String()
} else {
builder.WriteRune(chars[i])
} }
start = -1
} }
return builder.String() return string(chars)
} }

View File

@@ -0,0 +1,42 @@
//go:build go1.18
// +build go1.18
package stringx
import (
"fmt"
"math/rand"
"strings"
"testing"
)
func FuzzReplacerReplace(f *testing.F) {
keywords := make(map[string]string)
for i := 0; i < 20; i++ {
keywords[Randn(rand.Intn(10)+5)] = Randn(rand.Intn(5) + 1)
}
rep := NewReplacer(keywords)
printableKeywords := func() string {
var buf strings.Builder
for k, v := range keywords {
fmt.Fprintf(&buf, "%q: %q,\n", k, v)
}
return buf.String()
}
f.Add(50)
f.Fuzz(func(t *testing.T, n int) {
text := Randn(rand.Intn(n%50+50) + 1)
defer func() {
if r := recover(); r != nil {
t.Errorf("mapping: %s\ntext: %s", printableKeywords(), text)
}
}()
val := rep.Replace(text)
keys := rep.(*replacer).node.find([]rune(val))
if len(keys) > 0 {
t.Errorf("mapping: %s\ntext: %s\nresult: %s\nmatch: %v",
printableKeywords(), text, val, keys)
}
})
}

View File

@@ -15,6 +15,14 @@ func TestReplacer_Replace(t *testing.T) {
assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五")) assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五"))
} }
func TestReplacer_ReplaceOverlap(t *testing.T) {
mapping := map[string]string{
"3d": "34",
"bc": "23",
}
assert.Equal(t, "a234e", NewReplacer(mapping).Replace("abcde"))
}
func TestReplacer_ReplaceSingleChar(t *testing.T) { func TestReplacer_ReplaceSingleChar(t *testing.T) {
mapping := map[string]string{ mapping := map[string]string{
"二": "2", "二": "2",
@@ -42,3 +50,99 @@ func TestReplacer_ReplaceMultiMatches(t *testing.T) {
} }
assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五")) assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五"))
} }
func TestReplacer_ReplaceJumpToFail(t *testing.T) {
mapping := map[string]string{
"bcdf": "1235",
"cde": "234",
}
assert.Equal(t, "ab234fg", NewReplacer(mapping).Replace("abcdefg"))
}
func TestReplacer_ReplaceJumpToFailDup(t *testing.T) {
mapping := map[string]string{
"bcdf": "1235",
"ccde": "2234",
}
assert.Equal(t, "ab2234fg", NewReplacer(mapping).Replace("abccdefg"))
}
func TestReplacer_ReplaceJumpToFailEnding(t *testing.T) {
mapping := map[string]string{
"bcdf": "1235",
"cdef": "2345",
}
assert.Equal(t, "ab2345", NewReplacer(mapping).Replace("abcdef"))
}
func TestReplacer_ReplaceEmpty(t *testing.T) {
mapping := map[string]string{
"bcdf": "1235",
"cdef": "2345",
}
assert.Equal(t, "", NewReplacer(mapping).Replace(""))
}
func TestFuzzReplacerCase1(t *testing.T) {
keywords := map[string]string{
"yQyJykiqoh": "xw",
"tgN70z": "Q2P",
"tXKhEn": "w1G8",
"5nfOW1XZO": "GN",
"f4Ov9i9nHD": "cT",
"1ov9Q": "Y",
"7IrC9n": "400i",
"JQLxonpHkOjv": "XI",
"DyHQ3c7": "Ygxux",
"ffyqJi": "u",
"UHuvXrbD8pni": "dN",
"LIDzNbUlTX": "g",
"yN9WZh2rkc8Q": "3U",
"Vhk11rz8CObceC": "jf",
"R0Rt4H2qChUQf": "7U5M",
"MGQzzPCVKjV9": "yYz",
"B5jUUl0u1XOY": "l4PZ",
"pdvp2qfLgG8X": "BM562",
"ZKl9qdApXJ2": "T",
"37jnugkSevU66": "aOHFX",
}
rep := NewReplacer(keywords)
text := "yjF8fyqJiiqrczOCVyoYbLvrMpnkj"
val := rep.Replace(text)
keys := rep.(*replacer).node.find([]rune(val))
if len(keys) > 0 {
t.Errorf("result: %s, match: %v", val, keys)
}
}
func TestFuzzReplacerCase2(t *testing.T) {
keywords := map[string]string{
"dmv2SGZvq9Yz": "TE",
"rCL5DRI9uFP8": "hvsc8",
"7pSA2jaomgg": "v",
"kWSQvjVOIAxR": "Oje",
"hgU5bYYkD3r6": "qCXu",
"0eh6uI": "MMlt",
"3USZSl85EKeMzw": "Pc",
"JONmQSuXa": "dX",
"EO1WIF": "G",
"uUmFJGVmacjF": "1N",
"DHpw7": "M",
"NYB2bm": "CPya",
"9FiNvBAHHNku5": "7FlDE",
"tJi3I4WxcY": "q5",
"sNJ8Z1ToBV0O": "tl",
"0iOg72QcPo": "RP",
"pSEqeL": "5KZ",
"GOyYqTgmvQ": "9",
"Qv4qCsj": "nl52E",
"wNQ5tOutYu5s8": "6iGa",
}
rep := NewReplacer(keywords)
text := "AoRxrdKWsGhFpXwVqMLWRL74OukwjBuBh0g7pSrk"
val := rep.Replace(text)
keys := rep.(*replacer).node.find([]rune(val))
if len(keys) > 0 {
t.Errorf("result: %s, match: %v", val, keys)
}
}

Some files were not shown because too many files have changed in this diff Show More