Compare commits

..

72 Commits

Author SHA1 Message Date
kevin
2ea0a843f8 chore: remove any keywords 2023-03-04 20:54:26 +08:00
Kevin Wan
9e0e01b2bc chore: add tests (#2960) 2023-03-04 20:38:50 +08:00
yangjinheng
af50a80d01 timeout writer add hijack 2023-03-04 20:38:45 +08:00
yangjinheng
703fb8d970 Update timeouthandler.go 2023-03-04 20:38:40 +08:00
MarkJoyMa
e964e530e1 x 2023-03-04 20:32:21 +08:00
MarkJoyMa
52265087d1 x 2023-03-04 20:32:16 +08:00
MarkJoyMa
b4c2677eb9 add ut 2023-03-04 20:32:10 +08:00
MarkJoyMa
30296fb1ca feat: conf add FillDefault func 2023-03-04 20:31:44 +08:00
zhoumingji
356c80defd Fix bug in dartgen: The property 'isEmpty' can't be unconditionally accessed because the receiver can be 'null' 2023-03-04 20:31:38 +08:00
zhoumingji
8c31525378 Fix bug in dartgen: Increase the processing logic when route.RequestType is empty 2023-03-04 20:31:30 +08:00
cui fliter
2cf09f3c36 fix functiom name
Signed-off-by: cui fliter <imcusg@gmail.com>
2023-03-04 20:31:20 +08:00
Kevin Wan
d41e542c92 feat: support grpc client keepalive config (#2950) 2023-03-04 20:30:31 +08:00
tanglihao
265a24ac6d fix code format style use const config.DefaultFormat 2023-03-04 20:30:21 +08:00
tanglihao
7d88fc39dc fix log name conflict 2023-03-04 20:30:16 +08:00
anqiansong
6957b6a344 format code 2023-03-04 20:30:10 +08:00
anqiansong
bca6a230c8 remove unused code 2023-03-04 20:30:04 +08:00
anqiansong
cc8413d683 remove unused code 2023-03-04 20:29:56 +08:00
anqiansong
3842283fa8 Fix #2879 2023-03-04 20:29:41 +08:00
qiying.wang
fe13a533f5 chore: remove redundant prefix of "error: " in error creation 2023-03-04 20:26:40 +08:00
qiying.wang
7a327ccda4 chore: add tests for logc debug 2023-03-04 20:25:52 +08:00
qiying.wang
06e4507406 feat: add debug log for logc 2023-03-04 20:25:27 +08:00
kevin
8794d5b753 chore: add comments 2023-03-04 20:25:21 +08:00
kevin
9bfa63d995 chore: add more tests 2023-03-04 20:25:15 +08:00
kevin
a432b121fb chore: add more tests 2023-03-04 20:25:07 +08:00
kevin
b61c94bb66 feat: check key overwritten 2023-03-04 20:24:33 +08:00
Kevin Wan
93fcf899dc fix: config map cannot handle case-insensitive keys. (#2932)
* fix: #2922

* chore: rename const

* feat: support anonymous map field

* feat: support anonymous map field
2023-03-04 20:23:53 +08:00
Kevin Wan
9f4b3bae92 fix: #2899, using autoscaling/v2beta2 instead of v2beta1 (#2900)
* fix: #2899, using autoscaling/v2 instead of v2beta1

* chore: change hpa definition

---------

Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
2023-03-04 20:22:27 +08:00
Kevin Wan
805cb87d98 chore: refine rest validator (#2928)
* chore: refine rest validator

* chore: add more tests

* chore: reformat code

* chore: add comments
2023-03-04 20:22:10 +08:00
Qiying Wang
366131640e feat: add configurable validator for httpx.Parse (#2923)
Co-authored-by: qiying.wang <qiying.wang@highlight.mobi>
2023-03-04 20:22:05 +08:00
Kevin Wan
956884a3ff fix: timeout not working if greater than global rest timeout (#2926) 2023-03-04 20:21:59 +08:00
raymonder jin
f571cb8af2 del unnecessary blank 2023-03-04 20:21:54 +08:00
Kevin Wan
cc5acf3b90 chore: reformat code (#2925) 2023-03-04 20:21:49 +08:00
chenquan
e1aa665443 fix: fixed the bug that old trace instances may be fetched 2023-03-04 20:21:43 +08:00
xiandong
cd357d9484 rm parseErr when kindJaeger 2023-03-04 20:21:28 +08:00
xiandong
6d4d7cbd6b rm kindJaegerUdp 2023-03-04 20:21:18 +08:00
xiandong
c593b5b531 add parseEndpoint 2023-03-04 20:20:29 +08:00
xiandong
fd5b38b07c add parseEndpoint 2023-03-04 20:20:17 +08:00
xiandong
41efb48f55 add test for Endpoint of kindJaegerUdp 2023-03-04 20:19:40 +08:00
xiandong
0ef3626839 add test for Endpoint of kindJaegerUdp 2023-03-04 20:19:34 +08:00
xiandong
77a72b16e9 add kindJaegerUdp 2023-03-04 20:19:25 +08:00
Kevin Wan
21566f1b7a chore: reformat code (#2903) 2023-03-04 20:17:35 +08:00
anqiansong
b2646e228b feat: Add request.ts (#2901)
* Add request.ts

* Update comments

* Refactor request filename
2023-03-04 20:17:21 +08:00
cong
588b883710 refactor: simplify sqlx fail fast ping and simplify miniredis setup in test (#2897)
* chore(redistest): simplify miniredis setup in test

* refactor(sqlx): simplify sqlx fail fast ping

* chore: close connection if not available
2023-03-04 20:17:16 +08:00
Kevin Wan
033910bbd8 Update readme-cn.md 2023-03-04 20:17:11 +08:00
fondoger
530dd79e3f Fix bug in dart api gen: path parameter is not replaced 2023-03-04 20:17:05 +08:00
Kevin Wan
cd5263ac75 Update readme-cn.md 2023-03-04 20:16:58 +08:00
Kevin Wan
ea3302a468 fix: test failures (#2892)
Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
2023-03-04 20:16:50 +08:00
fondoger
abf15b373c Fix Dart API generation bugs; Add ability to generate API for path parameters (#2887)
* Fix bug in dartgen: Import path should match the generated api filename

* Use Route.HandlerName as generated dart API function name

Reasons:
- There is bug when using url path name as function name, because it may have invalid characters such as ":"
- Switching to HandlerName aligns with other languages such as typescript generation

* [DartGen] Add ability to generate api for url path parameters such as /path/:param
2023-03-04 20:16:44 +08:00
Kevin Wan
a865e9ee29 refactor: simplify stringx.Replacer, and avoid potential infinite loops (#2877)
* simplify replace

* backup

* refactor: simplify stringx.Replacer

* chore: add comments and const

* chore: add more tests

* chore: rename variable
2023-03-04 20:16:37 +08:00
Kevin Wan
f8292198cf Update readme-cn.md 2023-03-04 20:15:38 +08:00
Kevin Wan
016d965f56 chore: refactor (#2875) 2023-03-04 20:15:30 +08:00
dahaihu
95d7c73409 fix Replacer suffix match, and add test case (#2867)
* fix: replace shoud replace the longest match

* feat: revert bytes.Buffer to strings.Builder

* fix: loop reset nextStart

* feat: add node longest match test

* feat: add replacer suffix match test case

* feat: multiple match

* fix: partial match ends

* fix: replace look back upon error

* feat: rm unnecessary branch

---------

Co-authored-by: hudahai <hscxrzs@gmail.com>
Co-authored-by: hushichang <hushichang@sensetime.com>
2023-03-04 20:15:25 +08:00
Kevin Wan
939ef2a181 chore: add more tests (#2873) 2023-03-04 20:15:18 +08:00
Kevin Wan
f0b8dd45fe fix: test failure (#2874) 2023-03-04 20:15:08 +08:00
Mikael
0ba9335b04 only unmashal public variables (#2872)
* only unmashal public variables

* only unmashal public variables

* only unmashal public variables

* only unmashal public variables
2023-03-04 20:15:01 +08:00
Kevin Wan
04f181f0b4 chore: add more tests (#2866)
* chore: add more tests

* chore: add more tests

* chore: fix test failure
2023-03-04 20:14:54 +08:00
hudahai
89f841c126 fix: loop reset nextStart 2023-03-04 20:14:48 +08:00
hudahai
d785c8c377 feat: revert bytes.Buffer to strings.Builder 2023-03-04 20:14:41 +08:00
hudahai
687a1d15da fix: replace shoud replace the longest match 2023-03-04 20:14:35 +08:00
Kevin Wan
aaa974e1ad Update readme-cn.md 2023-03-04 20:14:22 +08:00
Kevin Wan
2779568ccf fix: conf anonymous overlay problem (#2847) 2023-03-04 20:14:10 +08:00
Kevin Wan
f7d50ae626 Update readme-cn.md 2023-03-04 20:14:01 +08:00
Kevin Wan
33594ea350 Chore/rewire (#2836)
* fix: problem on name overlaping in config (#2820)

* chore: fix missing funcs on windows (#2825)

* chore: add more tests (#2812)

* chore: add more tests

* chore: add more tests

* chore: add more tests (#2814)

* chore: add more tests (#2815)

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* feat: upgrade go to v1.18 (#2817)

* feat: upgrade go to v1.18

* feat: upgrade go to v1.18

* chore: change interface{} to any (#2818)

* chore: change interface{} to any

* chore: update goctl version to 1.5.0

* chore: update goctl deps

* chore: update goctl interface{} to any (#2819)

* chore: update goctl interface{} to any

* chore: update goctl interface{} to any

* chore(deps): bump google.golang.org/grpc from 1.52.0 to 1.52.3 (#2823)

* support custom maxBytes in API file (#2822)

* feat: mapreduce generic version (#2827)

* feat: mapreduce generic version

* fix: gateway mr type issue

---------

Co-authored-by: kevin.wan <kevin.wan@yijinin.com>

* feat: add MustNewRedis (#2824)

* feat: add MustNewRedis

* feat: add MustNewRedis

* feat: add MustNewRedis

* x

* x

* fix ut

* x

* x

* x

* x

* x

* chore: improve codecov (#2828)

* feat: converge grpc interceptor processing (#2830)

* feat: converge grpc interceptor processing

* x

* x

* chore(deps): bump go.opentelemetry.io/otel/exporters/zipkin (#2831)

* chore(deps): bump go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp (#2833)

Bumps [go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp](https://github.com/open-telemetry/opentelemetry-go) from 1.11.2 to 1.12.0.
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.11.2...v1.12.0)

---
updated-dependencies:
- dependency-name: go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* chore(deps): bump go.opentelemetry.io/otel/exporters/jaeger (#2832)

Bumps [go.opentelemetry.io/otel/exporters/jaeger](https://github.com/open-telemetry/opentelemetry-go) from 1.11.2 to 1.12.0.
- [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.11.2...v1.12.0)

---
updated-dependencies:
- dependency-name: go.opentelemetry.io/otel/exporters/jaeger
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Xiaoju Jiang <44432198+jiang4869@users.noreply.github.com>
Co-authored-by: kevin.wan <kevin.wan@yijinin.com>
Co-authored-by: MarkJoyMa <64180138+MarkJoyMa@users.noreply.github.com>
2023-03-04 20:13:37 +08:00
MarkJoyMa
ee2ec974c4 feat: converge grpc interceptor processing (#2830)
* feat: converge grpc interceptor processing

* x

* x
2023-03-04 20:12:30 +08:00
Kevin Wan
fd2f2f0f54 chore: improve codecov (#2828) 2023-03-04 20:12:16 +08:00
MarkJoyMa
86a2429d7d feat: add MustNewRedis (#2824)
* feat: add MustNewRedis

* feat: add MustNewRedis

* feat: add MustNewRedis

* x

* x

* fix ut

* x

* x

* x

* x

* x
2023-03-04 20:12:05 +08:00
Xiaoju Jiang
e5fe5dcc50 support custom maxBytes in API file (#2822) 2023-03-04 20:11:55 +08:00
Kevin Wan
b510e7c242 chore: fix missing funcs on windows (#2825) 2023-03-04 20:11:46 +08:00
Kevin Wan
dfe92e709f fix: problem on name overlaping in config (#2820) 2023-03-04 20:11:18 +08:00
Kevin Wan
cb649cf627 chore: add more tests (#2815)
* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests

* chore: add more tests
2023-03-04 20:11:03 +08:00
Kevin Wan
ce19a5ade6 chore: add more tests (#2814) 2023-03-04 20:10:57 +08:00
Kevin Wan
6dc56de714 chore: add more tests (#2812)
* chore: add more tests

* chore: add more tests
2023-03-04 20:09:03 +08:00
911 changed files with 14741 additions and 72446 deletions

6
.codecov.yml Normal file
View File

@@ -0,0 +1,6 @@
comment:
layout: "flags, files"
behavior: once
require_changes: true
ignore:
- "tools"

View File

@@ -1,7 +1 @@
**/.git **/.git
.dockerignore
Dockerfile
goctl
Makefile
readme.md
readme-cn.md

12
.github/FUNDING.yml vendored
View File

@@ -1,3 +1,13 @@
# These are supported funding model platforms # These are supported funding model platforms
github: [zeromicro] github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: # https://gitee.com/kevwan/static/raw/master/images/sponsor.jpg
ethereum: 0x5052b7f6B937B02563996D23feb69b38D06Ca150 | kevwan

View File

@@ -1,197 +0,0 @@
# GitHub Copilot Instructions for go-zero
This document provides guidelines for GitHub Copilot when assisting with development in the go-zero project.
## Project Overview
go-zero is a web and RPC framework with lots of built-in engineering practices designed to ensure the stability of busy services with resilience design. It has been serving sites with tens of millions of users for years.
### Key Architecture Components
- **REST API framework** (`rest/`) - HTTP service framework with middleware support
- **RPC framework** (`zrpc/`) - gRPC-based RPC framework with service discovery
- **Core utilities** (`core/`) - Foundational components including:
- Circuit breakers, rate limiters, load shedding
- Caching, stores (Redis, MongoDB, SQL)
- Concurrency control, metrics, tracing
- Configuration management
- **Code generation tool** (`tools/goctl/`) - CLI tool for generating code from API files
## Coding Standards and Conventions
### Code Style
1. **Follow Go conventions**: Use `gofmt` for formatting, follow effective Go practices
2. **Package naming**: Use lowercase, single-word package names when possible
3. **Error handling**: Always handle errors explicitly, use `errorx.BatchError` for multiple errors
4. **Context propagation**: Always pass `context.Context` as the first parameter for functions that may block
5. **Configuration structures**: Use struct tags with JSON annotations and default values
Example configuration pattern:
```go
type Config struct {
Host string `json:",default=0.0.0.0"`
Port int `json:",default=8080"`
Timeout int `json:",default=3000"`
Optional string `json:",optional"`
}
```
### Interface Design
1. **Small interfaces**: Follow Go's preference for small, focused interfaces
2. **Context methods**: Provide both context and non-context versions of methods
3. **Options pattern**: Use functional options for complex configuration
Example:
```go
func (c *Client) Get(key string, val any) error {
return c.GetCtx(context.Background(), key, val)
}
func (c *Client) GetCtx(ctx context.Context, key string, val any) error {
// implementation
}
```
### Testing Patterns
1. **Test file naming**: Use `*_test.go` suffix
2. **Test function naming**: Use `TestFunctionName` pattern
3. **Use testify/assert**: Prefer `assert` package for assertions
4. **Table-driven tests**: Use table-driven tests for multiple scenarios
5. **Mock interfaces**: Use `go.uber.org/mock` for mocking
6. **Test helpers**: Use `redistest`, `mongtest` helpers for database testing
Example test pattern:
```go
func TestSomething(t *testing.T) {
tests := []struct {
name string
input string
expected string
wantErr bool
}{
{"valid case", "input", "output", false},
{"error case", "bad", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := SomeFunction(tt.input)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
```
## Framework-Specific Guidelines
### REST API Development
1. **API Definition**: Use `.api` files to define REST APIs
2. **Handler pattern**: Separate business logic into logic packages
3. **Middleware**: Use built-in middlewares (tracing, logging, metrics, recovery)
4. **Response handling**: Use `httpx.WriteJson` for JSON responses
5. **Error handling**: Use `httpx.Error` for HTTP error responses
### RPC Development
1. **Protocol Buffers**: Use protobuf for service definitions
2. **Service discovery**: Integrate with etcd for service registration
3. **Load balancing**: Use built-in load balancing strategies
4. **Interceptors**: Implement interceptors for cross-cutting concerns
### Database Operations
1. **SQL operations**: Use `sqlx` package for database operations
2. **Caching**: Implement caching patterns with `cache` package
3. **Transactions**: Use proper transaction handling
4. **Connection pooling**: Configure appropriate connection pools
Example cache pattern:
```go
err := c.QueryRowCtx(ctx, &dest, key, func(ctx context.Context, conn sqlx.SqlConn) error {
return conn.QueryRowCtx(ctx, &dest, query, args...)
})
```
### Configuration Management
1. **YAML configuration**: Use YAML for configuration files
2. **Environment variables**: Support environment variable overrides
3. **Validation**: Include proper validation for configuration parameters
4. **Sensible defaults**: Provide reasonable default values
## Error Handling Best Practices
1. **Wrap errors**: Use `fmt.Errorf` with `%w` verb to wrap errors
2. **Custom errors**: Define custom error types when needed
3. **Error logging**: Log errors appropriately with context
4. **Graceful degradation**: Implement fallback mechanisms
## Performance Considerations
1. **Resource pools**: Use connection pools and worker pools
2. **Circuit breakers**: Implement circuit breaker patterns for external calls
3. **Rate limiting**: Apply rate limiting to protect services
4. **Load shedding**: Implement adaptive load shedding
5. **Metrics**: Add appropriate metrics and monitoring
## Security Guidelines
1. **Input validation**: Validate all input parameters
2. **SQL injection prevention**: Use parameterized queries
3. **Authentication**: Implement proper JWT token handling
4. **HTTPS**: Support TLS/HTTPS configurations
5. **CORS**: Configure CORS appropriately for web APIs
## Documentation Standards
1. **Package documentation**: Include package-level documentation
2. **Function documentation**: Document exported functions with examples
3. **API documentation**: Maintain API documentation in sync
4. **README updates**: Update README for significant changes
## Common Patterns to Follow
### Service Configuration
```go
type ServiceConf struct {
Name string
Log logx.LogConf
Mode string `json:",default=pro,options=[dev,test,pre,pro]"`
// ... other common fields
}
```
### Middleware Implementation
```go
func SomeMiddleware() rest.Middleware {
return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Pre-processing
next.ServeHTTP(w, r)
// Post-processing
}
}
}
```
### Resource Management
Always implement proper resource cleanup using defer and context cancellation.
## Build and Test Commands
- Build: `go build ./...`
- Test: `go test ./...`
- Test with race detection: `go test -race ./...`
- Format: `gofmt -w .`
- Generate code: `goctl api go -api *.api -dir .`
Remember to run tests and ensure all checks pass before submitting changes. The project emphasizes high quality, performance, and reliability, so these should be primary considerations in all development work.

View File

@@ -5,19 +5,7 @@
version: 2 version: 2
updates: updates:
- package-ecosystem: "docker" # Update image tags in Dockerfile
directory: "/"
schedule:
interval: "weekly"
- package-ecosystem: "github-actions" # Update GitHub Actions
directory: "/"
schedule:
interval: "weekly"
- package-ecosystem: "gomod" # See documentation for possible values - package-ecosystem: "gomod" # See documentation for possible values
directory: "/" # Location of package manifests directory: "/" # Location of package manifests
schedule: schedule:
interval: "daily" interval: "daily"
- package-ecosystem: "gomod" # See documentation for possible values
directory: "/tools/goctl" # Location of package manifests
schedule:
interval: "daily"

View File

@@ -35,11 +35,11 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v5 uses: actions/checkout@v3
# Initializes the CodeQL tools for scanning. # Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL - name: Initialize CodeQL
uses: github/codeql-action/init@v4 uses: github/codeql-action/init@v2
with: with:
languages: ${{ matrix.language }} languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file. # If you wish to specify custom queries, you can do so here or in a config file.
@@ -50,7 +50,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below) # If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild - name: Autobuild
uses: github/codeql-action/autobuild@v4 uses: github/codeql-action/autobuild@v2
# Command-line programs to run using the OS shell. # Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl # 📚 https://git.io/JvXDl
@@ -64,4 +64,4 @@ jobs:
# make release # make release
- name: Perform CodeQL Analysis - name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4 uses: github/codeql-action/analyze@v2

View File

@@ -12,12 +12,12 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Check out code into the Go module directory - name: Check out code into the Go module directory
uses: actions/checkout@v5 uses: actions/checkout@v3
- name: Set up Go 1.x - name: Set up Go 1.x
uses: actions/setup-go@v6 uses: actions/setup-go@v3
with: with:
go-version-file: go.mod go-version: ^1.16
check-latest: true check-latest: true
cache: true cache: true
id: go id: go
@@ -29,36 +29,27 @@ jobs:
- name: Lint - name: Lint
run: | run: |
go vet -stdmethods=false $(go list ./...) go vet -stdmethods=false $(go list ./...)
go install mvdan.cc/gofumpt@latest
go mod tidy test -z "$(gofumpt -l -extra .)" || echo "Please run 'gofumpt -l -w -extra .'"
if ! test -z "$(git status --porcelain)"; then
echo "Please run 'go mod tidy'"
exit 1
fi
- name: Test - name: Test
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./... run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
- name: Codecov - name: Codecov
uses: codecov/codecov-action@v5 uses: codecov/codecov-action@v3
with:
files: ./coverage.txt
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false
test-win: test-win:
name: Windows name: Windows
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout codebase - name: Checkout codebase
uses: actions/checkout@v5 uses: actions/checkout@v3
- name: Set up Go 1.x - name: Set up Go 1.x
uses: actions/setup-go@v6 uses: actions/setup-go@v3
with: with:
# make sure Go version compatible with go-zero # use 1.16 to guarantee Go 1.16 compatibility
go-version-file: go.mod go-version: 1.16
check-latest: true check-latest: true
cache: true cache: true
@@ -66,5 +57,5 @@ jobs:
run: | run: |
go mod verify go mod verify
go mod download go mod download
go test ./... go test -v -race ./...
cd tools/goctl && go build -v goctl.go cd tools/goctl && go build -v goctl.go

18
.github/workflows/issue-translator.yml vendored Normal file
View File

@@ -0,0 +1,18 @@
name: 'issue-translator'
on:
issue_comment:
types: [created]
issues:
types: [opened]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: usthe/issues-translate-action@v2.7
with:
IS_MODIFY_TITLE: true
# not require, default false, . Decide whether to modify the issue title
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑‍🤝‍🧑👫🧑🏿‍🤝‍🧑🏻👩🏾‍🤝‍👨🏿👬🏿
# not require. Customize the translation robot prefix message.

View File

@@ -7,7 +7,7 @@ jobs:
close-issues: close-issues:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/stale@v10 - uses: actions/stale@v6
with: with:
days-before-issue-stale: 365 days-before-issue-stale: 365
days-before-issue-close: 90 days-before-issue-close: 90

View File

@@ -16,13 +16,13 @@ jobs:
- goarch: "386" - goarch: "386"
goos: darwin goos: darwin
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v3
- uses: zeromicro/go-zero-release-action@master - uses: zeromicro/go-zero-release-action@master
with: with:
github_token: ${{ secrets.GITHUB_TOKEN }} github_token: ${{ secrets.GITHUB_TOKEN }}
goos: ${{ matrix.goos }} goos: ${{ matrix.goos }}
goarch: ${{ matrix.goarch }} goarch: ${{ matrix.goarch }}
goversion: "https://dl.google.com/go/go1.21.13.linux-amd64.tar.gz" goversion: "https://dl.google.com/go/go1.17.5.linux-amd64.tar.gz"
project_path: "tools/goctl" project_path: "tools/goctl"
binary_name: "goctl" binary_name: "goctl"
extra_files: tools/goctl/readme.md tools/goctl/readme-cn.md extra_files: tools/goctl/readme.md tools/goctl/readme-cn.md

View File

@@ -5,7 +5,7 @@ jobs:
name: runner / staticcheck name: runner / staticcheck
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v3
- uses: reviewdog/action-staticcheck@v1 - uses: reviewdog/action-staticcheck@v1
with: with:
github_token: ${{ secrets.github_token }} github_token: ${{ secrets.github_token }}
@@ -14,6 +14,6 @@ jobs:
# Report all results. # Report all results.
filter_mode: nofilter filter_mode: nofilter
# Exit with 1 when it find at least one finding. # Exit with 1 when it find at least one finding.
fail_level: any fail_on_error: true
# Set staticcheck flags # Set staticcheck flags
staticcheck_flags: -checks=inherit,-SA1019,-SA1029,-SA5008 staticcheck_flags: -checks=inherit,-SA1019,-SA1029,-SA5008

View File

@@ -1,42 +0,0 @@
name: Release Version Check
on:
push:
tags:
- 'tools/goctl/v*'
workflow_dispatch:
jobs:
version-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: '1.21'
- name: Extract tag version
id: get_version
run: |
# Extract version from tools/goctl/v* format
VERSION="${GITHUB_REF#refs/tags/tools/goctl/v}"
echo "VERSION=$VERSION" >> $GITHUB_ENV
echo "Extracted version: $VERSION"
- name: Check version in goctl source code
run: |
# Change to goctl directory
cd tools/goctl
# Check version in BuildVersion constant
VERSION_IN_CODE=$(grep -r "const BuildVersion =" . | grep -o '".*"' | tr -d '"')
echo "Version in code: $VERSION_IN_CODE"
echo "Expected version: $VERSION"
if [ "$VERSION_IN_CODE" != "$VERSION" ]; then
echo "Version mismatch: Version in code ($VERSION_IN_CODE) doesn't match tag version ($VERSION)"
exit 1
fi
echo "✅ Version check passed!"

7
.gitignore vendored
View File

@@ -11,15 +11,12 @@
!api !api
# ignore # ignore
**/.idea .idea
**/.vscode
**/.DS_Store **/.DS_Store
**/logs **/logs
**/adhoc
**/coverage.txt
**/WARP.md
# for test purpose # for test purpose
**/adhoc
go.work go.work
go.work.sum go.work.sum

View File

@@ -1,76 +1,102 @@
# 🚀 Contributing to go-zero # Contributing
Welcome to the go-zero community! We're thrilled to have you here. Contributing to our project is a fantastic way to be a part of the go-zero journey. Let's make this guide exciting and fun! Welcome to go-zero!
## 📜 Before You Dive In - [Before you get started](#before-you-get-started)
- [Code of Conduct](#code-of-conduct)
- [Community Expectations](#community-expectations)
- [Getting started](#getting-started)
- [Your First Contribution](#your-first-contribution)
- [Find something to work on](#find-something-to-work-on)
- [Find a good first topic](#find-a-good-first-topic)
- [Work on an Issue](#work-on-an-issue)
- [File an Issue](#file-an-issue)
- [Contributor Workflow](#contributor-workflow)
- [Creating Pull Requests](#creating-pull-requests)
- [Code Review](#code-review)
- [Testing](#testing)
### 🤝 Code of Conduct # Before you get started
Let's start on the right foot. Please take a moment to read and embrace our [Code of Conduct](/code-of-conduct.md). We're all about creating a welcoming and respectful environment. ## Code of Conduct
### 🌟 Community Expectations Please make sure to read and observe our [Code of Conduct](/code-of-conduct.md).
At go-zero, we're like a close-knit family, and we believe in creating a healthy, friendly, and productive atmosphere. It's all about sharing knowledge and building amazing things together. ## Community Expectations
## 🚀 Getting Started go-zero is a community project driven by its community which strives to promote a healthy, friendly and productive environment.
go-zero is a web and rpc framework written in Go. It's born to ensure the stability of the busy sites with resilient design. Builtin goctl greatly improves the development productivity.
Get your adventure rolling! Here's how to begin: # Getting started
1. 🍴 **Fork the Repository**: Head over to the GitHub repository and fork it to your own space. - Fork the repository on GitHub.
- Make your changes on your fork repository.
- Submit a PR.
2. 🛠️ **Make Your Magic**: Work your magic in your forked repository. Create new features, squash bugs, or improve documentation - it's your world to conquer!
3. 🚀 **Submit a PR (Pull Request)**: When you're ready to unveil your creation, submit a Pull Request. We can't wait to see your awesome work! # Your First Contribution
## 🌟 Your First Contribution We will help you to contribute in different areas like filing issues, developing features, fixing critical bugs and
getting your work reviewed and merged.
We're here to guide you on your quest to become a go-zero contributor. Whether you want to file issues, develop features, or tame some critical bugs, we've got you covered. If you have questions about the development process,
feel free to [file an issue](https://github.com/zeromicro/go-zero/issues/new/choose).
If you have questions or need guidance at any stage, don't hesitate to [open an issue](https://github.com/zeromicro/go-zero/issues/new/choose). ## Find something to work on
## 🔍 Find Something to Work On We are always in need of help, be it fixing documentation, reporting bugs or writing some code.
Look at places where you feel best coding practices aren't followed, code refactoring is needed or tests are missing.
Here is how you get started.
Ready to dive into the action? There are several ways to contribute: ### Find a good first topic
### 💼 Find a Good First Topic [go-zero](https://github.com/zeromicro/go-zero) has beginner-friendly issues that provide a good first issue.
For example, [go-zero](https://github.com/zeromicro/go-zero) has
[help wanted](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) and
[good first issue](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
labels for issues that should not need deep knowledge of the system.
We can help new contributors who wish to work on such issues.
Discover easy-entry issues labeled as [help wanted](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) or [good first issue](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). These issues are perfect for newcomers and don't require deep knowledge of the system. We're here to assist you with these tasks. Another good way to contribute is to find a documentation improvement, such as a missing/broken link.
Please see [Contributing](#contributing) below for the workflow.
### 🪄 Work on an Issue #### Work on an issue
Once you've picked an issue that excites you, let us know by commenting on it. Our maintainers will assign it to you, and you can embark on your mission! When you are willing to take on an issue, just reply on the issue. The maintainer will assign it to you.
### 📢 File an Issue ### File an Issue
Reporting an issue is just as valuable as code contributions. If you discover a problem, don't hesitate to [open an issue](https://github.com/zeromicro/go-zero/issues/new/choose). Be sure to follow our guidelines when submitting an issue. While we encourage everyone to contribute code, it is also appreciated when someone reports an issue.
## 🎯 Contributor Workflow Please follow the prompted submission guidelines while opening an issue.
Here's a rough guide to your contributor journey: # Contributor Workflow
1. 🌱 Create a New Branch: Start by creating a topic branch, usually based on the 'master' branch. This is where your contribution will grow. Please do not ever hesitate to ask a question or send a pull request.
2. 💡 Make Commits: Commit your work in logical units. Each commit should tell a story. This is a rough outline of what a contributor's workflow looks like:
3. 🚀 Push Changes: Push the changes in your topic branch to your personal fork of the repository. - Create a topic branch from where to base the contribution. This is usually master.
- Make commits of logical units.
- Push changes in a topic branch to a personal fork of the repository.
- Submit a pull request to [go-zero](https://github.com/zeromicro/go-zero).
4. 📦 Submit a Pull Request: When your creation is complete, submit a Pull Request to the [go-zero repository](https://github.com/zeromicro/go-zero). ## Creating Pull Requests
## 🌠 Creating Pull Requests Pull requests are often called simply "PR".
go-zero generally follows the standard [github pull request](https://help.github.com/articles/about-pull-requests/) process.
To submit a proposed change, please develop the code/fix and add new test cases.
After that, run these local verifications before submitting pull request to predict the pass or
fail of continuous integration.
Pull Requests (PRs) are your way of making a grand entrance with your contribution. Here's how to do it: * Format the code with `gofmt`
* Run the test with data race enabled `go test -race ./...`
- 💼 Format Your Code: Ensure your code is beautifully formatted with `gofmt`. ## Code Review
- 🏃 Run Tests: Verify that your changes pass all the tests, including data race tests. Run `go test -race ./...` for the ultimate validation.
## 👁️‍🗨️ Code Review To make it easier for your PR to receive reviews, consider the reviewers will need you to:
Getting your PR reviewed is the final step before your contribution becomes part of go-zero's magical world. To make the process smooth, keep these things in mind: * follow [good coding guidelines](https://github.com/golang/go/wiki/CodeReviewComments).
* write [good commit messages](https://chris.beams.io/posts/git-commit/).
* break large changes into a logical series of smaller patches which individually make easily understandable changes, and in aggregate solve a broader issue.
- 🧙‍♀️ Follow Good Coding Practices: Stick to [good coding guidelines](https://github.com/golang/go/wiki/CodeReviewComments).
- 📝 Write Awesome Commit Messages: Craft [impressive commit messages](https://chris.beams.io/posts/git-commit/) - they're like spells in the wizard's book!
- 🔍 Break It Down: For larger changes, consider breaking them into a series of smaller, logical patches. Each patch should make an understandable and meaningful improvement.
Congratulations on your contribution journey! We're thrilled to have you as part of our go-zero community. Let's make amazing things together! 🌟
Now, go out there and start your adventure! If you have any more magical ideas to enhance this guide, please share them. 🔥

View File

@@ -1,16 +0,0 @@
# Security Policy
## Supported Versions
We publish releases monthly.
| Version | Supported |
| ------- | ------------------ |
| >= 1.4.4 | :white_check_mark: |
| < 1.4.4 | :x: |
## Reporting a Vulnerability
https://github.com/zeromicro/go-zero/security/advisories
Accepted vulnerabilities are expected to be fixed within a month.

View File

@@ -1,127 +1,76 @@
# Contributor Covenant Code of Conduct # Contributor Covenant Code of Conduct
## Our Pledge ## Our Pledge
We as members, contributors, and leaders pledge to make participation in our In the interest of fostering an open and welcoming environment, we as
community a harassment-free experience for everyone, regardless of age, body contributors and maintainers pledge to make participation in our project and
size, visible or invisible disability, ethnicity, sex characteristics, gender our community a harassment-free experience for everyone, regardless of age, body
identity and expression, level of experience, education, socio-economic status, size, disability, ethnicity, sex characteristics, gender identity and expression,
nationality, personal appearance, race, caste, color, religion, or sexual level of experience, education, socio-economic status, nationality, personal
identity and orientation. appearance, race, religion, or sexual identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards ## Our Standards
Examples of behavior that contributes to a positive environment for our Examples of behavior that contributes to creating a positive environment
community include: include:
* Demonstrating empathy and kindness toward other people * Using welcoming and inclusive language
* Being respectful of differing opinions, viewpoints, and experiences * Being respectful of differing viewpoints and experiences
* Giving and gracefully accepting constructive feedback * Gracefully accepting constructive criticism
* Accepting responsibility and apologizing to those affected by our mistakes, * Focusing on what is best for the community
and learning from the experience * Showing empathy towards other community members
* Focusing on what is best not just for us as individuals, but for the overall
community
Examples of unacceptable behavior include: Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery, and sexual attention or advances of * The use of sexualized language or imagery and unwelcome sexual attention or
any kind advances
* Trolling, insulting or derogatory comments, and personal or political attacks * Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment * Public or private harassment
* Publishing others' private information, such as a physical or email address, * Publishing others' private information, such as a physical or electronic
without their explicit permission address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a * Other conduct which could reasonably be considered inappropriate in a
professional setting professional setting
## Enforcement Responsibilities ## Our Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of Project maintainers are responsible for clarifying the standards of acceptable
acceptable behavior and will take appropriate and fair corrective action in behavior and are expected to take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive, response to any instances of unacceptable behavior.
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject Project maintainers have the right and responsibility to remove, edit, or
comments, commits, code, wiki edits, issues, and other contributions that are reject comments, commits, code, wiki edits, issues, and other contributions
not aligned to this Code of Conduct, and will communicate reasons for moderation that are not aligned to this Code of Conduct, or to ban temporarily or
decisions when appropriate. permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope ## Scope
This Code of Conduct applies within all community spaces, and also applies when This Code of Conduct applies within all project spaces, and it also applies when
an individual is officially representing the community in public spaces. an individual is representing the project or its community in public spaces.
Examples of representing our community include using an official e-mail address, Examples of representing a project or community include using an official
posting via an official social media account, or acting as an appointed project e-mail address, posting via an official social media account, or acting
representative at an online or offline event. as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
## Enforcement ## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at reported by contacting the project team at [INSERT EMAIL ADDRESS]. All
[INSERT CONTACT METHOD]. complaints will be reviewed and investigated and will result in a response that
All complaints will be reviewed and investigated promptly and fairly. is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
All community leaders are obligated to respect the privacy and security of the Project maintainers who do not follow or enforce the Code of Conduct in good
reporter of any incident. faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of
actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the
community.
## Attribution ## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
version 2.1, available at available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
Community Impact Guidelines were inspired by [homepage]: https://www.contributor-covenant.org
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
For answers to common questions about this code of conduct, see the FAQ at For answers to common questions about this code of conduct, see
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at https://www.contributor-covenant.org/faq
[https://www.contributor-covenant.org/translations][translations].

View File

@@ -1,8 +1,6 @@
package bloom package bloom
import ( import (
"context"
_ "embed"
"errors" "errors"
"strconv" "strconv"
@@ -10,23 +8,28 @@ import (
"github.com/zeromicro/go-zero/core/stores/redis" "github.com/zeromicro/go-zero/core/stores/redis"
) )
// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html const (
// maps as k in the error rate table // for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
const maps = 14 // maps as k in the error rate table
maps = 14
var ( setScript = `
// ErrTooLargeOffset indicates the offset is too large in bitset. for _, offset in ipairs(ARGV) do
ErrTooLargeOffset = errors.New("too large offset") redis.call("setbit", KEYS[1], offset, 1)
end
//go:embed setscript.lua `
setLuaScript string testScript = `
setScript = redis.NewScript(setLuaScript) for _, offset in ipairs(ARGV) do
if tonumber(redis.call("getbit", KEYS[1], offset)) == 0 then
//go:embed testscript.lua return false
testLuaScript string end
testScript = redis.NewScript(testLuaScript) end
return true
`
) )
// ErrTooLargeOffset indicates the offset is too large in bitset.
var ErrTooLargeOffset = errors.New("too large offset")
type ( type (
// A Filter is a bloom filter. // A Filter is a bloom filter.
Filter struct { Filter struct {
@@ -35,8 +38,8 @@ type (
} }
bitSetProvider interface { bitSetProvider interface {
check(ctx context.Context, offsets []uint) (bool, error) check([]uint) (bool, error)
set(ctx context.Context, offsets []uint) error set([]uint) error
} }
) )
@@ -55,24 +58,14 @@ func New(store *redis.Redis, key string, bits uint) *Filter {
// Add adds data into f. // Add adds data into f.
func (f *Filter) Add(data []byte) error { func (f *Filter) Add(data []byte) error {
return f.AddCtx(context.Background(), data)
}
// AddCtx adds data into f with context.
func (f *Filter) AddCtx(ctx context.Context, data []byte) error {
locations := f.getLocations(data) locations := f.getLocations(data)
return f.bitSet.set(ctx, locations) return f.bitSet.set(locations)
} }
// Exists checks if data is in f. // Exists checks if data is in f.
func (f *Filter) Exists(data []byte) (bool, error) { func (f *Filter) Exists(data []byte) (bool, error) {
return f.ExistsCtx(context.Background(), data)
}
// ExistsCtx checks if data is in f with context.
func (f *Filter) ExistsCtx(ctx context.Context, data []byte) (bool, error) {
locations := f.getLocations(data) locations := f.getLocations(data)
isSet, err := f.bitSet.check(ctx, locations) isSet, err := f.bitSet.check(locations)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -105,7 +98,7 @@ func newRedisBitSet(store *redis.Redis, key string, bits uint) *redisBitSet {
} }
func (r *redisBitSet) buildOffsetArgs(offsets []uint) ([]string, error) { func (r *redisBitSet) buildOffsetArgs(offsets []uint) ([]string, error) {
args := make([]string, 0, len(offsets)) var args []string
for _, offset := range offsets { for _, offset := range offsets {
if offset >= r.bits { if offset >= r.bits {
@@ -118,14 +111,14 @@ func (r *redisBitSet) buildOffsetArgs(offsets []uint) ([]string, error) {
return args, nil return args, nil
} }
func (r *redisBitSet) check(ctx context.Context, offsets []uint) (bool, error) { func (r *redisBitSet) check(offsets []uint) (bool, error) {
args, err := r.buildOffsetArgs(offsets) args, err := r.buildOffsetArgs(offsets)
if err != nil { if err != nil {
return false, err return false, err
} }
resp, err := r.store.ScriptRunCtx(ctx, testScript, []string{r.key}, args) resp, err := r.store.Eval(testScript, []string{r.key}, args)
if errors.Is(err, redis.Nil) { if err == redis.Nil {
return false, nil return false, nil
} else if err != nil { } else if err != nil {
return false, err return false, err
@@ -139,25 +132,23 @@ func (r *redisBitSet) check(ctx context.Context, offsets []uint) (bool, error) {
return exists == 1, nil return exists == 1, nil
} }
// del only use for testing.
func (r *redisBitSet) del() error { func (r *redisBitSet) del() error {
_, err := r.store.Del(r.key) _, err := r.store.Del(r.key)
return err return err
} }
// expire only use for testing.
func (r *redisBitSet) expire(seconds int) error { func (r *redisBitSet) expire(seconds int) error {
return r.store.Expire(r.key, seconds) return r.store.Expire(r.key, seconds)
} }
func (r *redisBitSet) set(ctx context.Context, offsets []uint) error { func (r *redisBitSet) set(offsets []uint) error {
args, err := r.buildOffsetArgs(offsets) args, err := r.buildOffsetArgs(offsets)
if err != nil { if err != nil {
return err return err
} }
_, err = r.store.ScriptRunCtx(ctx, setScript, []string{r.key}, args) _, err = r.store.Eval(setScript, []string{r.key}, args)
if errors.Is(err, redis.Nil) { if err == redis.Nil {
return nil return nil
} }

View File

@@ -1,31 +1,30 @@
package bloom package bloom
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis/redistest" "github.com/zeromicro/go-zero/core/stores/redis/redistest"
) )
func TestRedisBitSet_New_Set_Test(t *testing.T) { func TestRedisBitSet_New_Set_Test(t *testing.T) {
store := redistest.CreateRedis(t) store, clean, err := redistest.CreateRedis()
ctx := context.Background() assert.Nil(t, err)
defer clean()
bitSet := newRedisBitSet(store, "test_key", 1024) bitSet := newRedisBitSet(store, "test_key", 1024)
isSetBefore, err := bitSet.check(ctx, []uint{0}) isSetBefore, err := bitSet.check([]uint{0})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if isSetBefore { if isSetBefore {
t.Fatal("Bit should not be set") t.Fatal("Bit should not be set")
} }
err = bitSet.set(ctx, []uint{512}) err = bitSet.set([]uint{512})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
isSetAfter, err := bitSet.check(ctx, []uint{512}) isSetAfter, err := bitSet.check([]uint{512})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -43,7 +42,9 @@ func TestRedisBitSet_New_Set_Test(t *testing.T) {
} }
func TestRedisBitSet_Add(t *testing.T) { func TestRedisBitSet_Add(t *testing.T) {
store := redistest.CreateRedis(t) store, clean, err := redistest.CreateRedis()
assert.Nil(t, err)
defer clean()
filter := New(store, "test_key", 64) filter := New(store, "test_key", 64)
assert.Nil(t, filter.Add([]byte("hello"))) assert.Nil(t, filter.Add([]byte("hello")))
@@ -52,51 +53,3 @@ func TestRedisBitSet_Add(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok) assert.True(t, ok)
} }
func TestFilter_Exists(t *testing.T) {
store, clean := redistest.CreateRedisWithClean(t)
rbs := New(store, "test", 64)
_, err := rbs.Exists([]byte{0, 1, 2})
assert.NoError(t, err)
clean()
rbs = New(store, "test", 64)
_, err = rbs.Exists([]byte{0, 1, 2})
assert.Error(t, err)
}
func TestRedisBitSet_check(t *testing.T) {
store, clean := redistest.CreateRedisWithClean(t)
ctx := context.Background()
rbs := newRedisBitSet(store, "test", 0)
assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
_, err := rbs.check(ctx, []uint{0, 1, 2})
assert.Error(t, err)
rbs = newRedisBitSet(store, "test", 64)
_, err = rbs.check(ctx, []uint{0, 1, 2})
assert.NoError(t, err)
clean()
rbs = newRedisBitSet(store, "test", 64)
_, err = rbs.check(ctx, []uint{0, 1, 2})
assert.Error(t, err)
}
func TestRedisBitSet_set(t *testing.T) {
logx.Disable()
store, clean := redistest.CreateRedisWithClean(t)
ctx := context.Background()
rbs := newRedisBitSet(store, "test", 0)
assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
rbs = newRedisBitSet(store, "test", 64)
assert.NoError(t, rbs.set(ctx, []uint{0, 1, 2}))
clean()
rbs = newRedisBitSet(store, "test", 64)
assert.Error(t, rbs.set(ctx, []uint{0, 1, 2}))
}

View File

@@ -1,3 +0,0 @@
for _, offset in ipairs(ARGV) do
redis.call("setbit", KEYS[1], offset, 1)
end

View File

@@ -1,6 +0,0 @@
for _, offset in ipairs(ARGV) do
if tonumber(redis.call("getbit", KEYS[1], offset)) == 0 then
return false
end
end
return true

View File

@@ -1,19 +1,22 @@
package breaker package breaker
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/proc" "github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/core/stringx"
) )
const numHistoryReasons = 5 const (
numHistoryReasons = 5
timeFormat = "15:04:05"
)
// ErrServiceUnavailable is returned when the Breaker state is open. // ErrServiceUnavailable is returned when the Breaker state is open.
var ErrServiceUnavailable = errors.New("circuit breaker is open") var ErrServiceUnavailable = errors.New("circuit breaker is open")
@@ -28,53 +31,38 @@ type (
Name() string Name() string
// Allow checks if the request is allowed. // Allow checks if the request is allowed.
// If allowed, a promise will be returned, // If allowed, a promise will be returned, the caller needs to call promise.Accept()
// otherwise ErrServiceUnavailable will be returned as the error. // on success, or call promise.Reject() on failure.
// The caller needs to call promise.Accept() on success, // If not allow, ErrServiceUnavailable will be returned.
// or call promise.Reject() on failure.
Allow() (Promise, error) Allow() (Promise, error)
// AllowCtx checks if the request is allowed when ctx isn't done.
AllowCtx(ctx context.Context) (Promise, error)
// Do runs the given request if the Breaker accepts it. // Do runs the given request if the Breaker accepts it.
// Do returns an error instantly if the Breaker rejects the request. // Do returns an error instantly if the Breaker rejects the request.
// If a panic occurs in the request, the Breaker handles it as an error // If a panic occurs in the request, the Breaker handles it as an error
// and causes the same panic again. // and causes the same panic again.
Do(req func() error) error Do(req func() error) error
// DoCtx runs the given request if the Breaker accepts it when ctx isn't done.
DoCtx(ctx context.Context, req func() error) error
// DoWithAcceptable runs the given request if the Breaker accepts it. // DoWithAcceptable runs the given request if the Breaker accepts it.
// DoWithAcceptable returns an error instantly if the Breaker rejects the request. // DoWithAcceptable returns an error instantly if the Breaker rejects the request.
// If a panic occurs in the request, the Breaker handles it as an error // If a panic occurs in the request, the Breaker handles it as an error
// and causes the same panic again. // and causes the same panic again.
// acceptable checks if it's a successful call, even if the error is not nil. // acceptable checks if it's a successful call, even if the err is not nil.
DoWithAcceptable(req func() error, acceptable Acceptable) error DoWithAcceptable(req func() error, acceptable Acceptable) error
// DoWithAcceptableCtx runs the given request if the Breaker accepts it when ctx isn't done.
DoWithAcceptableCtx(ctx context.Context, req func() error, acceptable Acceptable) error
// DoWithFallback runs the given request if the Breaker accepts it. // DoWithFallback runs the given request if the Breaker accepts it.
// DoWithFallback runs the fallback if the Breaker rejects the request. // DoWithFallback runs the fallback if the Breaker rejects the request.
// If a panic occurs in the request, the Breaker handles it as an error // If a panic occurs in the request, the Breaker handles it as an error
// and causes the same panic again. // and causes the same panic again.
DoWithFallback(req func() error, fallback Fallback) error DoWithFallback(req func() error, fallback func(err error) error) error
// DoWithFallbackCtx runs the given request if the Breaker accepts it when ctx isn't done.
DoWithFallbackCtx(ctx context.Context, req func() error, fallback Fallback) error
// DoWithFallbackAcceptable runs the given request if the Breaker accepts it. // DoWithFallbackAcceptable runs the given request if the Breaker accepts it.
// DoWithFallbackAcceptable runs the fallback if the Breaker rejects the request. // DoWithFallbackAcceptable runs the fallback if the Breaker rejects the request.
// If a panic occurs in the request, the Breaker handles it as an error // If a panic occurs in the request, the Breaker handles it as an error
// and causes the same panic again. // and causes the same panic again.
// acceptable checks if it's a successful call, even if the error is not nil. // acceptable checks if it's a successful call, even if the err is not nil.
DoWithFallbackAcceptable(req func() error, fallback Fallback, acceptable Acceptable) error DoWithFallbackAcceptable(req func() error, fallback func(err error) error, acceptable Acceptable) error
// DoWithFallbackAcceptableCtx runs the given request if the Breaker accepts it when ctx isn't done.
DoWithFallbackAcceptableCtx(ctx context.Context, req func() error, fallback Fallback,
acceptable Acceptable) error
} }
// Fallback is the func to be called if the request is rejected.
Fallback func(err error) error
// Option defines the method to customize a Breaker. // Option defines the method to customize a Breaker.
Option func(breaker *circuitBreaker) Option func(breaker *circuitBreaker)
@@ -98,12 +86,12 @@ type (
internalThrottle interface { internalThrottle interface {
allow() (internalPromise, error) allow() (internalPromise, error)
doReq(req func() error, fallback Fallback, acceptable Acceptable) error doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error
} }
throttle interface { throttle interface {
allow() (Promise, error) allow() (Promise, error)
doReq(req func() error, fallback Fallback, acceptable Acceptable) error doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error
} }
) )
@@ -126,71 +114,23 @@ func (cb *circuitBreaker) Allow() (Promise, error) {
return cb.throttle.allow() return cb.throttle.allow()
} }
func (cb *circuitBreaker) AllowCtx(ctx context.Context) (Promise, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
return cb.Allow()
}
}
func (cb *circuitBreaker) Do(req func() error) error { func (cb *circuitBreaker) Do(req func() error) error {
return cb.throttle.doReq(req, nil, defaultAcceptable) return cb.throttle.doReq(req, nil, defaultAcceptable)
} }
func (cb *circuitBreaker) DoCtx(ctx context.Context, req func() error) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return cb.Do(req)
}
}
func (cb *circuitBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error { func (cb *circuitBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error {
return cb.throttle.doReq(req, nil, acceptable) return cb.throttle.doReq(req, nil, acceptable)
} }
func (cb *circuitBreaker) DoWithAcceptableCtx(ctx context.Context, req func() error, func (cb *circuitBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
acceptable Acceptable) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return cb.DoWithAcceptable(req, acceptable)
}
}
func (cb *circuitBreaker) DoWithFallback(req func() error, fallback Fallback) error {
return cb.throttle.doReq(req, fallback, defaultAcceptable) return cb.throttle.doReq(req, fallback, defaultAcceptable)
} }
func (cb *circuitBreaker) DoWithFallbackCtx(ctx context.Context, req func() error, func (cb *circuitBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
fallback Fallback) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return cb.DoWithFallback(req, fallback)
}
}
func (cb *circuitBreaker) DoWithFallbackAcceptable(req func() error, fallback Fallback,
acceptable Acceptable) error { acceptable Acceptable) error {
return cb.throttle.doReq(req, fallback, acceptable) return cb.throttle.doReq(req, fallback, acceptable)
} }
func (cb *circuitBreaker) DoWithFallbackAcceptableCtx(ctx context.Context, req func() error,
fallback Fallback, acceptable Acceptable) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return cb.DoWithFallbackAcceptable(req, fallback, acceptable)
}
}
func (cb *circuitBreaker) Name() string { func (cb *circuitBreaker) Name() string {
return cb.name return cb.name
} }
@@ -228,7 +168,7 @@ func (lt loggedThrottle) allow() (Promise, error) {
}, lt.logError(err) }, lt.logError(err)
} }
func (lt loggedThrottle) doReq(req func() error, fallback Fallback, acceptable Acceptable) error { func (lt loggedThrottle) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
return lt.logError(lt.internalThrottle.doReq(req, fallback, func(err error) bool { return lt.logError(lt.internalThrottle.doReq(req, fallback, func(err error) bool {
accept := acceptable(err) accept := acceptable(err)
if !accept && err != nil { if !accept && err != nil {
@@ -239,7 +179,7 @@ func (lt loggedThrottle) doReq(req func() error, fallback Fallback, acceptable A
} }
func (lt loggedThrottle) logError(err error) error { func (lt loggedThrottle) logError(err error) error {
if errors.Is(err, ErrServiceUnavailable) { if err == ErrServiceUnavailable {
// if circuit open, not possible to have empty error window // if circuit open, not possible to have empty error window
stat.Report(fmt.Sprintf( stat.Report(fmt.Sprintf(
"proc(%s/%d), callee: %s, breaker is open and requests dropped\nlast errors:\n%s", "proc(%s/%d), callee: %s, breaker is open and requests dropped\nlast errors:\n%s",
@@ -258,14 +198,14 @@ type errorWindow struct {
func (ew *errorWindow) add(reason string) { func (ew *errorWindow) add(reason string) {
ew.lock.Lock() ew.lock.Lock()
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(time.TimeOnly), reason) ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason)
ew.index = (ew.index + 1) % numHistoryReasons ew.index = (ew.index + 1) % numHistoryReasons
ew.count = min(ew.count+1, numHistoryReasons) ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
ew.lock.Unlock() ew.lock.Unlock()
} }
func (ew *errorWindow) String() string { func (ew *errorWindow) String() string {
reasons := make([]string, 0, ew.count) var reasons []string
ew.lock.Lock() ew.lock.Lock()
// reverse order // reverse order

View File

@@ -1,13 +1,11 @@
package breaker package breaker
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/core/stat"
@@ -18,274 +16,10 @@ func init() {
} }
func TestCircuitBreaker_Allow(t *testing.T) { func TestCircuitBreaker_Allow(t *testing.T) {
t.Run("allow", func(t *testing.T) { b := NewBreaker()
b := NewBreaker() assert.True(t, len(b.Name()) > 0)
assert.True(t, len(b.Name()) > 0) _, err := b.Allow()
_, err := b.Allow() assert.Nil(t, err)
assert.Nil(t, err)
})
t.Run("allow with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
_, err := b.AllowCtx(context.Background())
assert.Nil(t, err)
})
t.Run("allow with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
_, err := b.AllowCtx(ctx)
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("allow with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
_, err := b.AllowCtx(ctx)
assert.ErrorIs(t, err, context.Canceled)
}
_, err := b.AllowCtx(context.Background())
assert.NoError(t, err)
})
}
func TestCircuitBreaker_Do(t *testing.T) {
t.Run("do", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.Do(func() error {
return nil
})
assert.Nil(t, err)
})
t.Run("do with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoCtx(context.Background(), func() error {
return nil
})
assert.Nil(t, err)
})
t.Run("do with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
err := b.DoCtx(ctx, func() error {
return nil
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("do with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
err := b.DoCtx(ctx, func() error {
return nil
})
assert.ErrorIs(t, err, context.Canceled)
}
assert.NoError(t, b.DoCtx(context.Background(), func() error {
return nil
}))
})
}
func TestCircuitBreaker_DoWithAcceptable(t *testing.T) {
t.Run("doWithAcceptable", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithAcceptable(func() error {
return nil
}, func(err error) bool {
return true
})
assert.Nil(t, err)
})
t.Run("doWithAcceptable with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithAcceptableCtx(context.Background(), func() error {
return nil
}, func(err error) bool {
return true
})
assert.Nil(t, err)
})
t.Run("doWithAcceptable with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
err := b.DoWithAcceptableCtx(ctx, func() error {
return nil
}, func(err error) bool {
return true
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("doWithAcceptable with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
err := b.DoWithAcceptableCtx(ctx, func() error {
return nil
}, func(err error) bool {
return true
})
assert.ErrorIs(t, err, context.Canceled)
}
assert.NoError(t, b.DoWithAcceptableCtx(context.Background(), func() error {
return nil
}, func(err error) bool {
return true
}))
})
}
func TestCircuitBreaker_DoWithFallback(t *testing.T) {
t.Run("doWithFallback", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithFallback(func() error {
return nil
}, func(err error) error {
return err
})
assert.Nil(t, err)
})
t.Run("doWithFallback with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithFallbackCtx(context.Background(), func() error {
return nil
}, func(err error) error {
return err
})
assert.Nil(t, err)
})
t.Run("doWithFallback with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
err := b.DoWithFallbackCtx(ctx, func() error {
return nil
}, func(err error) error {
return err
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("doWithFallback with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
err := b.DoWithFallbackCtx(ctx, func() error {
return nil
}, func(err error) error {
return err
})
assert.ErrorIs(t, err, context.Canceled)
}
assert.NoError(t, b.DoWithFallbackCtx(context.Background(), func() error {
return nil
}, func(err error) error {
return err
}))
})
}
func TestCircuitBreaker_DoWithFallbackAcceptable(t *testing.T) {
t.Run("doWithFallbackAcceptable", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithFallbackAcceptable(func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
})
assert.Nil(t, err)
})
t.Run("doWithFallbackAcceptable with ctx", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
err := b.DoWithFallbackAcceptableCtx(context.Background(), func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
})
assert.Nil(t, err)
})
t.Run("doWithFallbackAcceptable with ctx timeout", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
defer cancel()
time.Sleep(time.Millisecond)
err := b.DoWithFallbackAcceptableCtx(ctx, func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
})
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("doWithFallbackAcceptable with ctx cancel", func(t *testing.T) {
b := NewBreaker()
assert.True(t, len(b.Name()) > 0)
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
cancel()
err := b.DoWithFallbackAcceptableCtx(ctx, func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
})
assert.ErrorIs(t, err, context.Canceled)
}
assert.NoError(t, b.DoWithFallbackAcceptableCtx(context.Background(), func() error {
return nil
}, func(err error) error {
return err
}, func(err error) bool {
return true
}))
})
} }
func TestLogReason(t *testing.T) { func TestLogReason(t *testing.T) {

View File

@@ -1,9 +1,6 @@
package breaker package breaker
import ( import "sync"
"context"
"sync"
)
var ( var (
lock sync.RWMutex lock sync.RWMutex
@@ -17,13 +14,6 @@ func Do(name string, req func() error) error {
}) })
} }
// DoCtx calls Breaker.DoCtx on the Breaker with given name.
func DoCtx(ctx context.Context, name string, req func() error) error {
return do(name, func(b Breaker) error {
return b.DoCtx(ctx, req)
})
}
// DoWithAcceptable calls Breaker.DoWithAcceptable on the Breaker with given name. // DoWithAcceptable calls Breaker.DoWithAcceptable on the Breaker with given name.
func DoWithAcceptable(name string, req func() error, acceptable Acceptable) error { func DoWithAcceptable(name string, req func() error, acceptable Acceptable) error {
return do(name, func(b Breaker) error { return do(name, func(b Breaker) error {
@@ -31,44 +21,21 @@ func DoWithAcceptable(name string, req func() error, acceptable Acceptable) erro
}) })
} }
// DoWithAcceptableCtx calls Breaker.DoWithAcceptableCtx on the Breaker with given name.
func DoWithAcceptableCtx(ctx context.Context, name string, req func() error,
acceptable Acceptable) error {
return do(name, func(b Breaker) error {
return b.DoWithAcceptableCtx(ctx, req, acceptable)
})
}
// DoWithFallback calls Breaker.DoWithFallback on the Breaker with given name. // DoWithFallback calls Breaker.DoWithFallback on the Breaker with given name.
func DoWithFallback(name string, req func() error, fallback Fallback) error { func DoWithFallback(name string, req func() error, fallback func(err error) error) error {
return do(name, func(b Breaker) error { return do(name, func(b Breaker) error {
return b.DoWithFallback(req, fallback) return b.DoWithFallback(req, fallback)
}) })
} }
// DoWithFallbackCtx calls Breaker.DoWithFallbackCtx on the Breaker with given name.
func DoWithFallbackCtx(ctx context.Context, name string, req func() error, fallback Fallback) error {
return do(name, func(b Breaker) error {
return b.DoWithFallbackCtx(ctx, req, fallback)
})
}
// DoWithFallbackAcceptable calls Breaker.DoWithFallbackAcceptable on the Breaker with given name. // DoWithFallbackAcceptable calls Breaker.DoWithFallbackAcceptable on the Breaker with given name.
func DoWithFallbackAcceptable(name string, req func() error, fallback Fallback, func DoWithFallbackAcceptable(name string, req func() error, fallback func(err error) error,
acceptable Acceptable) error { acceptable Acceptable) error {
return do(name, func(b Breaker) error { return do(name, func(b Breaker) error {
return b.DoWithFallbackAcceptable(req, fallback, acceptable) return b.DoWithFallbackAcceptable(req, fallback, acceptable)
}) })
} }
// DoWithFallbackAcceptableCtx calls Breaker.DoWithFallbackAcceptableCtx on the Breaker with given name.
func DoWithFallbackAcceptableCtx(ctx context.Context, name string, req func() error,
fallback Fallback, acceptable Acceptable) error {
return do(name, func(b Breaker) error {
return b.DoWithFallbackAcceptableCtx(ctx, req, fallback, acceptable)
})
}
// GetBreaker returns the Breaker with the given name. // GetBreaker returns the Breaker with the given name.
func GetBreaker(name string) Breaker { func GetBreaker(name string) Breaker {
lock.RLock() lock.RLock()
@@ -92,7 +59,7 @@ func GetBreaker(name string) Breaker {
// NoBreakerFor disables the circuit breaker for the given name. // NoBreakerFor disables the circuit breaker for the given name.
func NoBreakerFor(name string) { func NoBreakerFor(name string) {
lock.Lock() lock.Lock()
breakers[name] = NopBreaker() breakers[name] = newNoOpBreaker()
lock.Unlock() lock.Unlock()
} }

View File

@@ -1,7 +1,6 @@
package breaker package breaker
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"testing" "testing"
@@ -23,9 +22,6 @@ func TestBreakersDo(t *testing.T) {
assert.Equal(t, errDummy, Do("any", func() error { assert.Equal(t, errDummy, Do("any", func() error {
return errDummy return errDummy
})) }))
assert.Equal(t, errDummy, DoCtx(context.Background(), "any", func() error {
return errDummy
}))
} }
func TestBreakersDoWithAcceptable(t *testing.T) { func TestBreakersDoWithAcceptable(t *testing.T) {
@@ -34,7 +30,7 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error { assert.Equal(t, errDummy, GetBreaker("anyone").DoWithAcceptable(func() error {
return errDummy return errDummy
}, func(err error) bool { }, func(err error) bool {
return err == nil || errors.Is(err, errDummy) return err == nil || err == errDummy
})) }))
} }
verify(t, func() bool { verify(t, func() bool {
@@ -42,13 +38,6 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
return nil return nil
}) == nil }) == nil
}) })
verify(t, func() bool {
return DoWithAcceptableCtx(context.Background(), "anyone", func() error {
return nil
}, func(err error) bool {
return true
}) == nil
})
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
err := DoWithAcceptable("another", func() error { err := DoWithAcceptable("another", func() error {
@@ -56,12 +45,12 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
}, func(err error) bool { }, func(err error) bool {
return err == nil return err == nil
}) })
assert.True(t, errors.Is(err, errDummy) || errors.Is(err, ErrServiceUnavailable)) assert.True(t, err == errDummy || err == ErrServiceUnavailable)
} }
verify(t, func() bool { verify(t, func() bool {
return errors.Is(Do("another", func() error { return ErrServiceUnavailable == Do("another", func() error {
return nil return nil
}), ErrServiceUnavailable) })
}) })
} }
@@ -86,24 +75,18 @@ func TestBreakersFallback(t *testing.T) {
}, func(err error) error { }, func(err error) error {
return nil return nil
}) })
assert.True(t, err == nil || errors.Is(err, errDummy)) assert.True(t, err == nil || err == errDummy)
err = DoWithFallbackCtx(context.Background(), "fallback", func() error {
return errDummy
}, func(err error) error {
return nil
})
assert.True(t, err == nil || errors.Is(err, errDummy))
} }
verify(t, func() bool { verify(t, func() bool {
return errors.Is(Do("fallback", func() error { return ErrServiceUnavailable == Do("fallback", func() error {
return nil return nil
}), ErrServiceUnavailable) })
}) })
} }
func TestBreakersAcceptableFallback(t *testing.T) { func TestBreakersAcceptableFallback(t *testing.T) {
errDummy := errors.New("any") errDummy := errors.New("any")
for i := 0; i < 5000; i++ { for i := 0; i < 10000; i++ {
err := DoWithFallbackAcceptable("acceptablefallback", func() error { err := DoWithFallbackAcceptable("acceptablefallback", func() error {
return errDummy return errDummy
}, func(err error) error { }, func(err error) error {
@@ -111,20 +94,12 @@ func TestBreakersAcceptableFallback(t *testing.T) {
}, func(err error) bool { }, func(err error) bool {
return err == nil return err == nil
}) })
assert.True(t, err == nil || errors.Is(err, errDummy)) assert.True(t, err == nil || err == errDummy)
err = DoWithFallbackAcceptableCtx(context.Background(), "acceptablefallback", func() error {
return errDummy
}, func(err error) error {
return nil
}, func(err error) bool {
return err == nil
})
assert.True(t, err == nil || errors.Is(err, errDummy))
} }
verify(t, func() bool { verify(t, func() bool {
return errors.Is(Do("acceptablefallback", func() error { return ErrServiceUnavailable == Do("acceptablefallback", func() error {
return nil return nil
}), ErrServiceUnavailable) })
}) })
} }
@@ -135,5 +110,5 @@ func verify(t *testing.T, fn func() bool) {
count++ count++
} }
} }
assert.True(t, count >= 75, fmt.Sprintf("should be greater than 75, actual %d", count)) assert.True(t, count >= 80, fmt.Sprintf("should be greater than 80, actual %d", count))
} }

View File

@@ -1,48 +0,0 @@
package breaker
const (
success = iota
fail
drop
)
// bucket defines the bucket that holds sum and num of additions.
type bucket struct {
Sum int64
Success int64
Failure int64
Drop int64
}
func (b *bucket) Add(v int64) {
switch v {
case fail:
b.fail()
case drop:
b.drop()
default:
b.succeed()
}
}
func (b *bucket) Reset() {
b.Sum = 0
b.Success = 0
b.Failure = 0
b.Drop = 0
}
func (b *bucket) drop() {
b.Sum++
b.Drop++
}
func (b *bucket) fail() {
b.Sum++
b.Failure++
}
func (b *bucket) succeed() {
b.Sum++
b.Success++
}

View File

@@ -1,43 +0,0 @@
package breaker
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBucketAdd(t *testing.T) {
b := &bucket{}
// Test succeed
b.Add(0) // Using 0 for success
assert.Equal(t, int64(1), b.Sum, "Sum should be incremented")
assert.Equal(t, int64(1), b.Success, "Success should be incremented")
assert.Equal(t, int64(0), b.Failure, "Failure should not be incremented")
assert.Equal(t, int64(0), b.Drop, "Drop should not be incremented")
// Test failure
b.Add(fail)
assert.Equal(t, int64(2), b.Sum, "Sum should be incremented")
assert.Equal(t, int64(1), b.Failure, "Failure should be incremented")
assert.Equal(t, int64(0), b.Drop, "Drop should not be incremented")
// Test drop
b.Add(drop)
assert.Equal(t, int64(3), b.Sum, "Sum should be incremented")
assert.Equal(t, int64(1), b.Drop, "Drop should be incremented")
}
func TestBucketReset(t *testing.T) {
b := &bucket{
Sum: 3,
Success: 1,
Failure: 1,
Drop: 1,
}
b.Reset()
assert.Equal(t, int64(0), b.Sum, "Sum should be reset to 0")
assert.Equal(t, int64(0), b.Success, "Success should be reset to 0")
assert.Equal(t, int64(0), b.Failure, "Failure should be reset to 0")
assert.Equal(t, int64(0), b.Drop, "Drop should be reset to 0")
}

View File

@@ -1,87 +1,57 @@
package breaker package breaker
import ( import (
"math"
"time" "time"
"github.com/zeromicro/go-zero/core/collection" "github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/mathx" "github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/timex"
) )
const ( const (
// 250ms for bucket duration // 250ms for bucket duration
window = time.Second * 10 window = time.Second * 10
buckets = 40 buckets = 40
forcePassDuration = time.Second k = 1.5
k = 1.5 protection = 5
minK = 1.1
protection = 5
) )
// googleBreaker is a netflixBreaker pattern from google. // googleBreaker is a netflixBreaker pattern from google.
// see Client-Side Throttling section in https://landing.google.com/sre/sre-book/chapters/handling-overload/ // see Client-Side Throttling section in https://landing.google.com/sre/sre-book/chapters/handling-overload/
type ( type googleBreaker struct {
googleBreaker struct { k float64
k float64 stat *collection.RollingWindow
stat *collection.RollingWindow[int64, *bucket] proba *mathx.Proba
proba *mathx.Proba }
lastPass *syncx.AtomicDuration
}
windowResult struct {
accepts int64
total int64
failingBuckets int64
workingBuckets int64
}
)
func newGoogleBreaker() *googleBreaker { func newGoogleBreaker() *googleBreaker {
bucketDuration := time.Duration(int64(window) / int64(buckets)) bucketDuration := time.Duration(int64(window) / int64(buckets))
st := collection.NewRollingWindow[int64, *bucket](func() *bucket { st := collection.NewRollingWindow(buckets, bucketDuration)
return new(bucket)
}, buckets, bucketDuration)
return &googleBreaker{ return &googleBreaker{
stat: st, stat: st,
k: k, k: k,
proba: mathx.NewProba(), proba: mathx.NewProba(),
lastPass: syncx.NewAtomicDuration(),
} }
} }
func (b *googleBreaker) accept() error { func (b *googleBreaker) accept() error {
var w float64 accepts, total := b.history()
history := b.history() weightedAccepts := b.k * float64(accepts)
w = b.k - (b.k-minK)*float64(history.failingBuckets)/buckets
weightedAccepts := mathx.AtLeast(w, minK) * float64(history.accepts)
// https://landing.google.com/sre/sre-book/chapters/handling-overload/#eq2101 // https://landing.google.com/sre/sre-book/chapters/handling-overload/#eq2101
// for better performance, no need to care about the negative ratio dropRatio := math.Max(0, (float64(total-protection)-weightedAccepts)/float64(total+1))
dropRatio := (float64(history.total-protection) - weightedAccepts) / float64(history.total+1)
if dropRatio <= 0 { if dropRatio <= 0 {
return nil return nil
} }
lastPass := b.lastPass.Load()
if lastPass > 0 && timex.Since(lastPass) > forcePassDuration {
b.lastPass.Set(timex.Now())
return nil
}
dropRatio *= float64(buckets-history.workingBuckets) / buckets
if b.proba.TrueOnProba(dropRatio) { if b.proba.TrueOnProba(dropRatio) {
return ErrServiceUnavailable return ErrServiceUnavailable
} }
b.lastPass.Set(timex.Now())
return nil return nil
} }
func (b *googleBreaker) allow() (internalPromise, error) { func (b *googleBreaker) allow() (internalPromise, error) {
if err := b.accept(); err != nil { if err := b.accept(); err != nil {
b.markDrop()
return nil, err return nil, err
} }
@@ -90,9 +60,8 @@ func (b *googleBreaker) allow() (internalPromise, error) {
}, nil }, nil
} }
func (b *googleBreaker) doReq(req func() error, fallback Fallback, acceptable Acceptable) error { func (b *googleBreaker) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
if err := b.accept(); err != nil { if err := b.accept(); err != nil {
b.markDrop()
if fallback != nil { if fallback != nil {
return fallback(err) return fallback(err)
} }
@@ -100,55 +69,38 @@ func (b *googleBreaker) doReq(req func() error, fallback Fallback, acceptable Ac
return err return err
} }
var succ bool
defer func() { defer func() {
// if req() panic, success is false, mark as failure if e := recover(); e != nil {
if succ {
b.markSuccess()
} else {
b.markFailure() b.markFailure()
panic(e)
} }
}() }()
err := req() err := req()
if acceptable(err) { if acceptable(err) {
succ = true b.markSuccess()
} else {
b.markFailure()
} }
return err return err
} }
func (b *googleBreaker) markDrop() { func (b *googleBreaker) markSuccess() {
b.stat.Add(drop) b.stat.Add(1)
} }
func (b *googleBreaker) markFailure() { func (b *googleBreaker) markFailure() {
b.stat.Add(fail) b.stat.Add(0)
} }
func (b *googleBreaker) markSuccess() { func (b *googleBreaker) history() (accepts, total int64) {
b.stat.Add(success) b.stat.Reduce(func(b *collection.Bucket) {
} accepts += int64(b.Sum)
total += b.Count
func (b *googleBreaker) history() windowResult {
var result windowResult
b.stat.Reduce(func(b *bucket) {
result.accepts += b.Success
result.total += b.Sum
if b.Failure > 0 {
result.workingBuckets = 0
} else if b.Success > 0 {
result.workingBuckets++
}
if b.Success > 0 {
result.failingBuckets = 0
} else if b.Failure > 0 {
result.failingBuckets++
}
}) })
return result return
} }
type googlePromise struct { type googlePromise struct {

View File

@@ -10,7 +10,6 @@ import (
"github.com/zeromicro/go-zero/core/collection" "github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/mathx" "github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/syncx"
) )
const ( const (
@@ -23,14 +22,11 @@ func init() {
} }
func getGoogleBreaker() *googleBreaker { func getGoogleBreaker() *googleBreaker {
st := collection.NewRollingWindow[int64, *bucket](func() *bucket { st := collection.NewRollingWindow(testBuckets, testInterval)
return new(bucket)
}, testBuckets, testInterval)
return &googleBreaker{ return &googleBreaker{
stat: st, stat: st,
k: 5, k: 5,
proba: mathx.NewProba(), proba: mathx.NewProba(),
lastPass: syncx.NewAtomicDuration(),
} }
} }
@@ -67,33 +63,6 @@ func TestGoogleBreakerOpen(t *testing.T) {
}) })
} }
func TestGoogleBreakerRecover(t *testing.T) {
st := collection.NewRollingWindow[int64, *bucket](func() *bucket {
return new(bucket)
}, testBuckets*2, testInterval)
b := &googleBreaker{
stat: st,
k: k,
proba: mathx.NewProba(),
lastPass: syncx.NewAtomicDuration(),
}
for i := 0; i < testBuckets; i++ {
for j := 0; j < 100; j++ {
b.stat.Add(1)
}
time.Sleep(testInterval)
}
for i := 0; i < testBuckets; i++ {
for j := 0; j < 100; j++ {
b.stat.Add(0)
}
time.Sleep(testInterval)
}
verify(t, func() bool {
return b.accept() == nil
})
}
func TestGoogleBreakerFallback(t *testing.T) { func TestGoogleBreakerFallback(t *testing.T) {
b := getGoogleBreaker() b := getGoogleBreaker()
markSuccess(b, 1) markSuccess(b, 1)
@@ -120,50 +89,13 @@ func TestGoogleBreakerReject(t *testing.T) {
}, nil, defaultAcceptable)) }, nil, defaultAcceptable))
} }
func TestGoogleBreakerMoreFallingBuckets(t *testing.T) {
t.Parallel()
t.Run("more falling buckets", func(t *testing.T) {
b := getGoogleBreaker()
func() {
stopChan := time.After(testInterval * 6)
for {
time.Sleep(time.Millisecond)
select {
case <-stopChan:
return
default:
assert.Error(t, b.doReq(func() error {
return errors.New("foo")
}, func(err error) error {
return err
}, func(err error) bool {
return err == nil
}))
}
}
}()
var count int
for i := 0; i < 100; i++ {
if errors.Is(b.doReq(func() error {
return ErrServiceUnavailable
}, nil, defaultAcceptable), ErrServiceUnavailable) {
count++
}
}
assert.True(t, count > 90)
})
}
func TestGoogleBreakerAcceptable(t *testing.T) { func TestGoogleBreakerAcceptable(t *testing.T) {
b := getGoogleBreaker() b := getGoogleBreaker()
errAcceptable := errors.New("any") errAcceptable := errors.New("any")
assert.Equal(t, errAcceptable, b.doReq(func() error { assert.Equal(t, errAcceptable, b.doReq(func() error {
return errAcceptable return errAcceptable
}, nil, func(err error) bool { }, nil, func(err error) bool {
return errors.Is(err, errAcceptable) return err == errAcceptable
})) }))
} }
@@ -173,7 +105,7 @@ func TestGoogleBreakerNotAcceptable(t *testing.T) {
assert.Equal(t, errAcceptable, b.doReq(func() error { assert.Equal(t, errAcceptable, b.doReq(func() error {
return errAcceptable return errAcceptable
}, nil, func(err error) bool { }, nil, func(err error) bool {
return !errors.Is(err, errAcceptable) return err != errAcceptable
})) }))
} }
@@ -232,38 +164,41 @@ func TestGoogleBreakerSelfProtection(t *testing.T) {
} }
func TestGoogleBreakerHistory(t *testing.T) { func TestGoogleBreakerHistory(t *testing.T) {
var b *googleBreaker
var accepts, total int64
sleep := testInterval sleep := testInterval
t.Run("accepts == total", func(t *testing.T) { t.Run("accepts == total", func(t *testing.T) {
b := getGoogleBreaker() b = getGoogleBreaker()
markSuccessWithDuration(b, 10, sleep/2) markSuccessWithDuration(b, 10, sleep/2)
result := b.history() accepts, total = b.history()
assert.Equal(t, int64(10), result.accepts) assert.Equal(t, int64(10), accepts)
assert.Equal(t, int64(10), result.total) assert.Equal(t, int64(10), total)
}) })
t.Run("fail == total", func(t *testing.T) { t.Run("fail == total", func(t *testing.T) {
b := getGoogleBreaker() b = getGoogleBreaker()
markFailedWithDuration(b, 10, sleep/2) markFailedWithDuration(b, 10, sleep/2)
result := b.history() accepts, total = b.history()
assert.Equal(t, int64(0), result.accepts) assert.Equal(t, int64(0), accepts)
assert.Equal(t, int64(10), result.total) assert.Equal(t, int64(10), total)
}) })
t.Run("accepts = 1/2 * total, fail = 1/2 * total", func(t *testing.T) { t.Run("accepts = 1/2 * total, fail = 1/2 * total", func(t *testing.T) {
b := getGoogleBreaker() b = getGoogleBreaker()
markFailedWithDuration(b, 5, sleep/2) markFailedWithDuration(b, 5, sleep/2)
markSuccessWithDuration(b, 5, sleep/2) markSuccessWithDuration(b, 5, sleep/2)
result := b.history() accepts, total = b.history()
assert.Equal(t, int64(5), result.accepts) assert.Equal(t, int64(5), accepts)
assert.Equal(t, int64(10), result.total) assert.Equal(t, int64(10), total)
}) })
t.Run("auto reset rolling counter", func(t *testing.T) { t.Run("auto reset rolling counter", func(t *testing.T) {
b := getGoogleBreaker() b = getGoogleBreaker()
time.Sleep(testInterval * testBuckets) time.Sleep(testInterval * testBuckets)
result := b.history() accepts, total = b.history()
assert.Equal(t, int64(0), result.accepts) assert.Equal(t, int64(0), accepts)
assert.Equal(t, int64(0), result.total) assert.Equal(t, int64(0), total)
}) })
} }
@@ -271,7 +206,7 @@ func BenchmarkGoogleBreakerAllow(b *testing.B) {
breaker := getGoogleBreaker() breaker := getGoogleBreaker()
b.ResetTimer() b.ResetTimer()
for i := 0; i <= b.N; i++ { for i := 0; i <= b.N; i++ {
_ = breaker.accept() breaker.accept()
if i%2 == 0 { if i%2 == 0 {
breaker.markSuccess() breaker.markSuccess()
} else { } else {
@@ -280,16 +215,6 @@ func BenchmarkGoogleBreakerAllow(b *testing.B) {
} }
} }
func BenchmarkGoogleBreakerDoReq(b *testing.B) {
breaker := getGoogleBreaker()
b.ResetTimer()
for i := 0; i <= b.N; i++ {
_ = breaker.doReq(func() error {
return nil
}, nil, defaultAcceptable)
}
}
func markSuccess(b *googleBreaker, count int) { func markSuccess(b *googleBreaker, count int) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
p, err := b.allow() p, err := b.allow()

View File

@@ -1,58 +1,35 @@
package breaker package breaker
import "context" const noOpBreakerName = "nopBreaker"
const nopBreakerName = "nopBreaker" type noOpBreaker struct{}
type nopBreaker struct{} func newNoOpBreaker() Breaker {
return noOpBreaker{}
// NopBreaker returns a breaker that never trigger breaker circuit.
func NopBreaker() Breaker {
return nopBreaker{}
} }
func (b nopBreaker) Name() string { func (b noOpBreaker) Name() string {
return nopBreakerName return noOpBreakerName
} }
func (b nopBreaker) Allow() (Promise, error) { func (b noOpBreaker) Allow() (Promise, error) {
return nopPromise{}, nil return nopPromise{}, nil
} }
func (b nopBreaker) AllowCtx(_ context.Context) (Promise, error) { func (b noOpBreaker) Do(req func() error) error {
return nopPromise{}, nil
}
func (b nopBreaker) Do(req func() error) error {
return req() return req()
} }
func (b nopBreaker) DoCtx(_ context.Context, req func() error) error { func (b noOpBreaker) DoWithAcceptable(req func() error, _ Acceptable) error {
return req() return req()
} }
func (b nopBreaker) DoWithAcceptable(req func() error, _ Acceptable) error { func (b noOpBreaker) DoWithFallback(req func() error, _ func(err error) error) error {
return req() return req()
} }
func (b nopBreaker) DoWithAcceptableCtx(_ context.Context, req func() error, _ Acceptable) error { func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, _ func(err error) error,
return req() _ Acceptable) error {
}
func (b nopBreaker) DoWithFallback(req func() error, _ Fallback) error {
return req()
}
func (b nopBreaker) DoWithFallbackCtx(_ context.Context, req func() error, _ Fallback) error {
return req()
}
func (b nopBreaker) DoWithFallbackAcceptable(req func() error, _ Fallback, _ Acceptable) error {
return req()
}
func (b nopBreaker) DoWithFallbackAcceptableCtx(_ context.Context, req func() error,
_ Fallback, _ Acceptable) error {
return req() return req()
} }

View File

@@ -1,7 +1,6 @@
package breaker package breaker
import ( import (
"context"
"errors" "errors"
"testing" "testing"
@@ -9,11 +8,9 @@ import (
) )
func TestNopBreaker(t *testing.T) { func TestNopBreaker(t *testing.T) {
b := NopBreaker() b := newNoOpBreaker()
assert.Equal(t, nopBreakerName, b.Name()) assert.Equal(t, noOpBreakerName, b.Name())
_, err := b.Allow() p, err := b.Allow()
assert.Nil(t, err)
p, err := b.AllowCtx(context.Background())
assert.Nil(t, err) assert.Nil(t, err)
p.Accept() p.Accept()
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
@@ -24,34 +21,18 @@ func TestNopBreaker(t *testing.T) {
assert.Nil(t, b.Do(func() error { assert.Nil(t, b.Do(func() error {
return nil return nil
})) }))
assert.Nil(t, b.DoCtx(context.Background(), func() error {
return nil
}))
assert.Nil(t, b.DoWithAcceptable(func() error { assert.Nil(t, b.DoWithAcceptable(func() error {
return nil return nil
}, defaultAcceptable)) }, defaultAcceptable))
assert.Nil(t, b.DoWithAcceptableCtx(context.Background(), func() error {
return nil
}, defaultAcceptable))
errDummy := errors.New("any") errDummy := errors.New("any")
assert.Equal(t, errDummy, b.DoWithFallback(func() error { assert.Equal(t, errDummy, b.DoWithFallback(func() error {
return errDummy return errDummy
}, func(err error) error { }, func(err error) error {
return nil return nil
})) }))
assert.Equal(t, errDummy, b.DoWithFallbackCtx(context.Background(), func() error {
return errDummy
}, func(err error) error {
return nil
}))
assert.Equal(t, errDummy, b.DoWithFallbackAcceptable(func() error { assert.Equal(t, errDummy, b.DoWithFallbackAcceptable(func() error {
return errDummy return errDummy
}, func(err error) error { }, func(err error) error {
return nil return nil
}, defaultAcceptable)) }, defaultAcceptable))
assert.Equal(t, errDummy, b.DoWithFallbackAcceptableCtx(context.Background(), func() error {
return errDummy
}, func(err error) error {
return nil
}, defaultAcceptable))
} }

View File

@@ -23,7 +23,7 @@ var (
zero = big.NewInt(0) zero = big.NewInt(0)
) )
// DhKey defines the Diffie-Hellman key. // DhKey defines the Diffie Hellman key.
type DhKey struct { type DhKey struct {
PriKey *big.Int PriKey *big.Int
PubKey *big.Int PubKey *big.Int
@@ -46,7 +46,7 @@ func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) {
return new(big.Int).Exp(pubKey, priKey, p), nil return new(big.Int).Exp(pubKey, priKey, p), nil
} }
// GenerateKey returns a Diffie-Hellman key. // GenerateKey returns a Diffie Hellman key.
func GenerateKey() (*DhKey, error) { func GenerateKey() (*DhKey, error) {
var err error var err error
var x *big.Int var x *big.Int

View File

@@ -2,8 +2,6 @@ package codec
import ( import (
"bytes" "bytes"
"compress/gzip"
"errors"
"fmt" "fmt"
"testing" "testing"
@@ -23,45 +21,3 @@ func TestGzip(t *testing.T) {
assert.True(t, len(bs) < buf.Len()) assert.True(t, len(bs) < buf.Len())
assert.Equal(t, buf.Bytes(), actual) assert.Equal(t, buf.Bytes(), actual)
} }
func TestGunzip(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
expectedErr error
}{
{
name: "valid input",
input: func() []byte {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
gz.Write([]byte("hello"))
gz.Close()
return buf.Bytes()
}(),
expected: []byte("hello"),
expectedErr: nil,
},
{
name: "invalid input",
input: []byte("invalid input"),
expected: nil,
expectedErr: gzip.ErrHeader,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result, err := Gunzip(test.input)
if !bytes.Equal(result, test.expected) {
t.Errorf("unexpected result: %v", result)
}
if !errors.Is(err, test.expectedErr) {
t.Errorf("unexpected error: %v", err)
}
})
}
}

View File

@@ -2,7 +2,6 @@ package codec
import ( import (
"encoding/base64" "encoding/base64"
"os"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -42,7 +41,6 @@ func TestCryption(t *testing.T) {
file, err := fs.TempFilenameWithText(priKey) file, err := fs.TempFilenameWithText(priKey)
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(file)
dec, err := NewRsaDecrypter(file) dec, err := NewRsaDecrypter(file)
assert.Nil(t, err) assert.Nil(t, err)
actual, err := dec.Decrypt(ret) actual, err := dec.Decrypt(ret)

View File

@@ -30,7 +30,7 @@ type (
Cache struct { Cache struct {
name string name string
lock sync.Mutex lock sync.Mutex
data map[string]any data map[string]interface{}
expire time.Duration expire time.Duration
timingWheel *TimingWheel timingWheel *TimingWheel
lruCache lru lruCache lru
@@ -43,7 +43,7 @@ type (
// NewCache returns a Cache with given expire. // NewCache returns a Cache with given expire.
func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) { func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
cache := &Cache{ cache := &Cache{
data: make(map[string]any), data: make(map[string]interface{}),
expire: expire, expire: expire,
lruCache: emptyLruCache, lruCache: emptyLruCache,
barrier: syncx.NewSingleFlight(), barrier: syncx.NewSingleFlight(),
@@ -59,7 +59,7 @@ func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
} }
cache.stats = newCacheStat(cache.name, cache.size) cache.stats = newCacheStat(cache.name, cache.size)
timingWheel, err := NewTimingWheel(time.Second, slots, func(k, v any) { timingWheel, err := NewTimingWheel(time.Second, slots, func(k, v interface{}) {
key, ok := k.(string) key, ok := k.(string)
if !ok { if !ok {
return return
@@ -85,7 +85,7 @@ func (c *Cache) Del(key string) {
} }
// Get returns the item with the given key from c. // Get returns the item with the given key from c.
func (c *Cache) Get(key string) (any, bool) { func (c *Cache) Get(key string) (interface{}, bool) {
value, ok := c.doGet(key) value, ok := c.doGet(key)
if ok { if ok {
c.stats.IncrementHit() c.stats.IncrementHit()
@@ -97,12 +97,12 @@ func (c *Cache) Get(key string) (any, bool) {
} }
// Set sets value into c with key. // Set sets value into c with key.
func (c *Cache) Set(key string, value any) { func (c *Cache) Set(key string, value interface{}) {
c.SetWithExpire(key, value, c.expire) c.SetWithExpire(key, value, c.expire)
} }
// SetWithExpire sets value into c with key and expire with the given value. // SetWithExpire sets value into c with key and expire with the given value.
func (c *Cache) SetWithExpire(key string, value any, expire time.Duration) { func (c *Cache) SetWithExpire(key string, value interface{}, expire time.Duration) {
c.lock.Lock() c.lock.Lock()
_, ok := c.data[key] _, ok := c.data[key]
c.data[key] = value c.data[key] = value
@@ -120,16 +120,16 @@ func (c *Cache) SetWithExpire(key string, value any, expire time.Duration) {
// Take returns the item with the given key. // Take returns the item with the given key.
// If the item is in c, return it directly. // If the item is in c, return it directly.
// If not, use fetch method to get the item, set into c and return it. // If not, use fetch method to get the item, set into c and return it.
func (c *Cache) Take(key string, fetch func() (any, error)) (any, error) { func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}, error) {
if val, ok := c.doGet(key); ok { if val, ok := c.doGet(key); ok {
c.stats.IncrementHit() c.stats.IncrementHit()
return val, nil return val, nil
} }
var fresh bool var fresh bool
val, err := c.barrier.Do(key, func() (any, error) { val, err := c.barrier.Do(key, func() (interface{}, error) {
// because O(1) on map search in memory, and fetch is an IO query, // because O(1) on map search in memory, and fetch is an IO query
// so we do double-check, cache might be taken by another call // so we do double check, cache might be taken by another call
if val, ok := c.doGet(key); ok { if val, ok := c.doGet(key); ok {
return val, nil return val, nil
} }
@@ -157,7 +157,7 @@ func (c *Cache) Take(key string, fetch func() (any, error)) (any, error) {
return val, nil return val, nil
} }
func (c *Cache) doGet(key string) (any, bool) { func (c *Cache) doGet(key string) (interface{}, bool) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()

View File

@@ -52,7 +52,7 @@ func TestCacheTake(t *testing.T) {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
cache.Take("first", func() (any, error) { cache.Take("first", func() (interface{}, error) {
atomic.AddInt32(&count, 1) atomic.AddInt32(&count, 1)
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
return "first element", nil return "first element", nil
@@ -76,7 +76,7 @@ func TestCacheTakeExists(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
cache.Set("first", "first element") cache.Set("first", "first element")
cache.Take("first", func() (any, error) { cache.Take("first", func() (interface{}, error) {
atomic.AddInt32(&count, 1) atomic.AddInt32(&count, 1)
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
return "first element", nil return "first element", nil
@@ -99,7 +99,7 @@ func TestCacheTakeError(t *testing.T) {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
_, err := cache.Take("first", func() (any, error) { _, err := cache.Take("first", func() (interface{}, error) {
atomic.AddInt32(&count, 1) atomic.AddInt32(&count, 1)
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
return "", errDummy return "", errDummy

View File

@@ -5,7 +5,7 @@ import "sync"
// A Queue is a FIFO queue. // A Queue is a FIFO queue.
type Queue struct { type Queue struct {
lock sync.Mutex lock sync.Mutex
elements []any elements []interface{}
size int size int
head int head int
tail int tail int
@@ -15,7 +15,7 @@ type Queue struct {
// NewQueue returns a Queue object. // NewQueue returns a Queue object.
func NewQueue(size int) *Queue { func NewQueue(size int) *Queue {
return &Queue{ return &Queue{
elements: make([]any, size), elements: make([]interface{}, size),
size: size, size: size,
} }
} }
@@ -30,12 +30,12 @@ func (q *Queue) Empty() bool {
} }
// Put puts element into q at the last position. // Put puts element into q at the last position.
func (q *Queue) Put(element any) { func (q *Queue) Put(element interface{}) {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()
if q.head == q.tail && q.count > 0 { if q.head == q.tail && q.count > 0 {
nodes := make([]any, len(q.elements)+q.size) nodes := make([]interface{}, len(q.elements)+q.size)
copy(nodes, q.elements[q.head:]) copy(nodes, q.elements[q.head:])
copy(nodes[len(q.elements)-q.head:], q.elements[:q.head]) copy(nodes[len(q.elements)-q.head:], q.elements[:q.head])
q.head = 0 q.head = 0
@@ -49,7 +49,7 @@ func (q *Queue) Put(element any) {
} }
// Take takes the first element out of q if not empty. // Take takes the first element out of q if not empty.
func (q *Queue) Take() (any, bool) { func (q *Queue) Take() (interface{}, bool) {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()

View File

@@ -4,7 +4,7 @@ import "sync"
// A Ring can be used as fixed size ring. // A Ring can be used as fixed size ring.
type Ring struct { type Ring struct {
elements []any elements []interface{}
index int index int
lock sync.RWMutex lock sync.RWMutex
} }
@@ -16,44 +16,36 @@ func NewRing(n int) *Ring {
} }
return &Ring{ return &Ring{
elements: make([]any, n), elements: make([]interface{}, n),
} }
} }
// Add adds v into r. // Add adds v into r.
func (r *Ring) Add(v any) { func (r *Ring) Add(v interface{}) {
r.lock.Lock() r.lock.Lock()
defer r.lock.Unlock() defer r.lock.Unlock()
rlen := len(r.elements) r.elements[r.index%len(r.elements)] = v
r.elements[r.index%rlen] = v
r.index++ r.index++
// prevent ring index overflow
if r.index >= rlen<<1 {
r.index -= rlen
}
} }
// Take takes all items from r. // Take takes all items from r.
func (r *Ring) Take() []any { func (r *Ring) Take() []interface{} {
r.lock.RLock() r.lock.RLock()
defer r.lock.RUnlock() defer r.lock.RUnlock()
var size int var size int
var start int var start int
rlen := len(r.elements) if r.index > len(r.elements) {
size = len(r.elements)
if r.index > rlen { start = r.index % len(r.elements)
size = rlen
start = r.index % rlen
} else { } else {
size = r.index size = r.index
} }
elements := make([]any, size) elements := make([]interface{}, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
elements[i] = r.elements[(start+i)%rlen] elements[i] = r.elements[(start+i)%len(r.elements)]
} }
return elements return elements

View File

@@ -19,7 +19,7 @@ func TestRingLess(t *testing.T) {
ring.Add(i) ring.Add(i)
} }
elements := ring.Take() elements := ring.Take()
assert.ElementsMatch(t, []any{0, 1, 2}, elements) assert.ElementsMatch(t, []interface{}{0, 1, 2}, elements)
} }
func TestRingMore(t *testing.T) { func TestRingMore(t *testing.T) {
@@ -28,7 +28,7 @@ func TestRingMore(t *testing.T) {
ring.Add(i) ring.Add(i)
} }
elements := ring.Take() elements := ring.Take()
assert.ElementsMatch(t, []any{6, 7, 8, 9, 10}, elements) assert.ElementsMatch(t, []interface{}{6, 7, 8, 9, 10}, elements)
} }
func TestRingAdd(t *testing.T) { func TestRingAdd(t *testing.T) {

View File

@@ -4,28 +4,18 @@ import (
"sync" "sync"
"time" "time"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/timex" "github.com/zeromicro/go-zero/core/timex"
) )
type ( type (
// BucketInterface is the interface that defines the buckets.
BucketInterface[T Numerical] interface {
Add(v T)
Reset()
}
// Numerical is the interface that restricts the numerical type.
Numerical = mathx.Numerical
// RollingWindowOption let callers customize the RollingWindow. // RollingWindowOption let callers customize the RollingWindow.
RollingWindowOption[T Numerical, B BucketInterface[T]] func(rollingWindow *RollingWindow[T, B]) RollingWindowOption func(rollingWindow *RollingWindow)
// RollingWindow defines a rolling window to calculate the events in buckets with the time interval. // RollingWindow defines a rolling window to calculate the events in buckets with time interval.
RollingWindow[T Numerical, B BucketInterface[T]] struct { RollingWindow struct {
lock sync.RWMutex lock sync.RWMutex
size int size int
win *window[T, B] win *window
interval time.Duration interval time.Duration
offset int offset int
ignoreCurrent bool ignoreCurrent bool
@@ -35,15 +25,14 @@ type (
// NewRollingWindow returns a RollingWindow that with size buckets and time interval, // NewRollingWindow returns a RollingWindow that with size buckets and time interval,
// use opts to customize the RollingWindow. // use opts to customize the RollingWindow.
func NewRollingWindow[T Numerical, B BucketInterface[T]](newBucket func() B, size int, func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow {
interval time.Duration, opts ...RollingWindowOption[T, B]) *RollingWindow[T, B] {
if size < 1 { if size < 1 {
panic("size must be greater than 0") panic("size must be greater than 0")
} }
w := &RollingWindow[T, B]{ w := &RollingWindow{
size: size, size: size,
win: newWindow[T, B](newBucket, size), win: newWindow(size),
interval: interval, interval: interval,
lastTime: timex.Now(), lastTime: timex.Now(),
} }
@@ -54,7 +43,7 @@ func NewRollingWindow[T Numerical, B BucketInterface[T]](newBucket func() B, siz
} }
// Add adds value to current bucket. // Add adds value to current bucket.
func (rw *RollingWindow[T, B]) Add(v T) { func (rw *RollingWindow) Add(v float64) {
rw.lock.Lock() rw.lock.Lock()
defer rw.lock.Unlock() defer rw.lock.Unlock()
rw.updateOffset() rw.updateOffset()
@@ -62,13 +51,13 @@ func (rw *RollingWindow[T, B]) Add(v T) {
} }
// Reduce runs fn on all buckets, ignore current bucket if ignoreCurrent was set. // Reduce runs fn on all buckets, ignore current bucket if ignoreCurrent was set.
func (rw *RollingWindow[T, B]) Reduce(fn func(b B)) { func (rw *RollingWindow) Reduce(fn func(b *Bucket)) {
rw.lock.RLock() rw.lock.RLock()
defer rw.lock.RUnlock() defer rw.lock.RUnlock()
var diff int var diff int
span := rw.span() span := rw.span()
// ignore the current bucket, because of partial data // ignore current bucket, because of partial data
if span == 0 && rw.ignoreCurrent { if span == 0 && rw.ignoreCurrent {
diff = rw.size - 1 diff = rw.size - 1
} else { } else {
@@ -80,7 +69,7 @@ func (rw *RollingWindow[T, B]) Reduce(fn func(b B)) {
} }
} }
func (rw *RollingWindow[T, B]) span() int { func (rw *RollingWindow) span() int {
offset := int(timex.Since(rw.lastTime) / rw.interval) offset := int(timex.Since(rw.lastTime) / rw.interval)
if 0 <= offset && offset < rw.size { if 0 <= offset && offset < rw.size {
return offset return offset
@@ -89,7 +78,7 @@ func (rw *RollingWindow[T, B]) span() int {
return rw.size return rw.size
} }
func (rw *RollingWindow[T, B]) updateOffset() { func (rw *RollingWindow) updateOffset() {
span := rw.span() span := rw.span()
if span <= 0 { if span <= 0 {
return return
@@ -108,54 +97,54 @@ func (rw *RollingWindow[T, B]) updateOffset() {
} }
// Bucket defines the bucket that holds sum and num of additions. // Bucket defines the bucket that holds sum and num of additions.
type Bucket[T Numerical] struct { type Bucket struct {
Sum T Sum float64
Count int64 Count int64
} }
func (b *Bucket[T]) Add(v T) { func (b *Bucket) add(v float64) {
b.Sum += v b.Sum += v
b.Count++ b.Count++
} }
func (b *Bucket[T]) Reset() { func (b *Bucket) reset() {
b.Sum = 0 b.Sum = 0
b.Count = 0 b.Count = 0
} }
type window[T Numerical, B BucketInterface[T]] struct { type window struct {
buckets []B buckets []*Bucket
size int size int
} }
func newWindow[T Numerical, B BucketInterface[T]](newBucket func() B, size int) *window[T, B] { func newWindow(size int) *window {
buckets := make([]B, size) buckets := make([]*Bucket, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
buckets[i] = newBucket() buckets[i] = new(Bucket)
} }
return &window[T, B]{ return &window{
buckets: buckets, buckets: buckets,
size: size, size: size,
} }
} }
func (w *window[T, B]) add(offset int, v T) { func (w *window) add(offset int, v float64) {
w.buckets[offset%w.size].Add(v) w.buckets[offset%w.size].add(v)
} }
func (w *window[T, B]) reduce(start, count int, fn func(b B)) { func (w *window) reduce(start, count int, fn func(b *Bucket)) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
fn(w.buckets[(start+i)%w.size]) fn(w.buckets[(start+i)%w.size])
} }
} }
func (w *window[T, B]) resetBucket(offset int) { func (w *window) resetBucket(offset int) {
w.buckets[offset%w.size].Reset() w.buckets[offset%w.size].reset()
} }
// IgnoreCurrentBucket lets the Reduce call ignore current bucket. // IgnoreCurrentBucket lets the Reduce call ignore current bucket.
func IgnoreCurrentBucket[T Numerical, B BucketInterface[T]]() RollingWindowOption[T, B] { func IgnoreCurrentBucket() RollingWindowOption {
return func(w *RollingWindow[T, B]) { return func(w *RollingWindow) {
w.ignoreCurrent = true w.ignoreCurrent = true
} }
} }

View File

@@ -12,24 +12,18 @@ import (
const duration = time.Millisecond * 50 const duration = time.Millisecond * 50
func TestNewRollingWindow(t *testing.T) { func TestNewRollingWindow(t *testing.T) {
assert.NotNil(t, NewRollingWindow[int64, *Bucket[int64]](func() *Bucket[int64] { assert.NotNil(t, NewRollingWindow(10, time.Second))
return new(Bucket[int64])
}, 10, time.Second))
assert.Panics(t, func() { assert.Panics(t, func() {
NewRollingWindow[int64, *Bucket[int64]](func() *Bucket[int64] { NewRollingWindow(0, time.Second)
return new(Bucket[int64])
}, 0, time.Second)
}) })
} }
func TestRollingWindowAdd(t *testing.T) { func TestRollingWindowAdd(t *testing.T) {
const size = 3 const size = 3
r := NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] { r := NewRollingWindow(size, duration)
return new(Bucket[float64])
}, size, duration)
listBuckets := func() []float64 { listBuckets := func() []float64 {
var buckets []float64 var buckets []float64
r.Reduce(func(b *Bucket[float64]) { r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum) buckets = append(buckets, b.Sum)
}) })
return buckets return buckets
@@ -53,12 +47,10 @@ func TestRollingWindowAdd(t *testing.T) {
func TestRollingWindowReset(t *testing.T) { func TestRollingWindowReset(t *testing.T) {
const size = 3 const size = 3
r := NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] { r := NewRollingWindow(size, duration, IgnoreCurrentBucket())
return new(Bucket[float64])
}, size, duration, IgnoreCurrentBucket[float64, *Bucket[float64]]())
listBuckets := func() []float64 { listBuckets := func() []float64 {
var buckets []float64 var buckets []float64
r.Reduce(func(b *Bucket[float64]) { r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum) buckets = append(buckets, b.Sum)
}) })
return buckets return buckets
@@ -80,19 +72,15 @@ func TestRollingWindowReset(t *testing.T) {
func TestRollingWindowReduce(t *testing.T) { func TestRollingWindowReduce(t *testing.T) {
const size = 4 const size = 4
tests := []struct { tests := []struct {
win *RollingWindow[float64, *Bucket[float64]] win *RollingWindow
expect float64 expect float64
}{ }{
{ {
win: NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] { win: NewRollingWindow(size, duration),
return new(Bucket[float64])
}, size, duration),
expect: 10, expect: 10,
}, },
{ {
win: NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] { win: NewRollingWindow(size, duration, IgnoreCurrentBucket()),
return new(Bucket[float64])
}, size, duration, IgnoreCurrentBucket[float64, *Bucket[float64]]()),
expect: 4, expect: 4,
}, },
} }
@@ -109,7 +97,7 @@ func TestRollingWindowReduce(t *testing.T) {
} }
} }
var result float64 var result float64
r.Reduce(func(b *Bucket[float64]) { r.Reduce(func(b *Bucket) {
result += b.Sum result += b.Sum
}) })
assert.Equal(t, test.expect, result) assert.Equal(t, test.expect, result)
@@ -120,12 +108,10 @@ func TestRollingWindowReduce(t *testing.T) {
func TestRollingWindowBucketTimeBoundary(t *testing.T) { func TestRollingWindowBucketTimeBoundary(t *testing.T) {
const size = 3 const size = 3
interval := time.Millisecond * 30 interval := time.Millisecond * 30
r := NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] { r := NewRollingWindow(size, interval)
return new(Bucket[float64])
}, size, interval)
listBuckets := func() []float64 { listBuckets := func() []float64 {
var buckets []float64 var buckets []float64
r.Reduce(func(b *Bucket[float64]) { r.Reduce(func(b *Bucket) {
buckets = append(buckets, b.Sum) buckets = append(buckets, b.Sum)
}) })
return buckets return buckets
@@ -152,9 +138,7 @@ func TestRollingWindowBucketTimeBoundary(t *testing.T) {
func TestRollingWindowDataRace(t *testing.T) { func TestRollingWindowDataRace(t *testing.T) {
const size = 3 const size = 3
r := NewRollingWindow[float64, *Bucket[float64]](func() *Bucket[float64] { r := NewRollingWindow(size, duration)
return new(Bucket[float64])
}, size, duration)
stop := make(chan bool) stop := make(chan bool)
go func() { go func() {
for { for {
@@ -173,7 +157,7 @@ func TestRollingWindowDataRace(t *testing.T) {
case <-stop: case <-stop:
return return
default: default:
r.Reduce(func(b *Bucket[float64]) {}) r.Reduce(func(b *Bucket) {})
} }
} }
}() }()

View File

@@ -14,23 +14,21 @@ type SafeMap struct {
lock sync.RWMutex lock sync.RWMutex
deletionOld int deletionOld int
deletionNew int deletionNew int
dirtyOld map[any]any dirtyOld map[interface{}]interface{}
dirtyNew map[any]any dirtyNew map[interface{}]interface{}
} }
// NewSafeMap returns a SafeMap. // NewSafeMap returns a SafeMap.
func NewSafeMap() *SafeMap { func NewSafeMap() *SafeMap {
return &SafeMap{ return &SafeMap{
dirtyOld: make(map[any]any), dirtyOld: make(map[interface{}]interface{}),
dirtyNew: make(map[any]any), dirtyNew: make(map[interface{}]interface{}),
} }
} }
// Del deletes the value with the given key from m. // Del deletes the value with the given key from m.
func (m *SafeMap) Del(key any) { func (m *SafeMap) Del(key interface{}) {
m.lock.Lock() m.lock.Lock()
defer m.lock.Unlock()
if _, ok := m.dirtyOld[key]; ok { if _, ok := m.dirtyOld[key]; ok {
delete(m.dirtyOld, key) delete(m.dirtyOld, key)
m.deletionOld++ m.deletionOld++
@@ -44,20 +42,21 @@ func (m *SafeMap) Del(key any) {
} }
m.dirtyOld = m.dirtyNew m.dirtyOld = m.dirtyNew
m.deletionOld = m.deletionNew m.deletionOld = m.deletionNew
m.dirtyNew = make(map[any]any) m.dirtyNew = make(map[interface{}]interface{})
m.deletionNew = 0 m.deletionNew = 0
} }
if m.deletionNew >= maxDeletion && len(m.dirtyNew) < copyThreshold { if m.deletionNew >= maxDeletion && len(m.dirtyNew) < copyThreshold {
for k, v := range m.dirtyNew { for k, v := range m.dirtyNew {
m.dirtyOld[k] = v m.dirtyOld[k] = v
} }
m.dirtyNew = make(map[any]any) m.dirtyNew = make(map[interface{}]interface{})
m.deletionNew = 0 m.deletionNew = 0
} }
m.lock.Unlock()
} }
// Get gets the value with the given key from m. // Get gets the value with the given key from m.
func (m *SafeMap) Get(key any) (any, bool) { func (m *SafeMap) Get(key interface{}) (interface{}, bool) {
m.lock.RLock() m.lock.RLock()
defer m.lock.RUnlock() defer m.lock.RUnlock()
@@ -71,7 +70,7 @@ func (m *SafeMap) Get(key any) (any, bool) {
// Range calls f sequentially for each key and value present in the map. // Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration. // If f returns false, range stops the iteration.
func (m *SafeMap) Range(f func(key, val any) bool) { func (m *SafeMap) Range(f func(key, val interface{}) bool) {
m.lock.RLock() m.lock.RLock()
defer m.lock.RUnlock() defer m.lock.RUnlock()
@@ -88,10 +87,8 @@ func (m *SafeMap) Range(f func(key, val any) bool) {
} }
// Set sets the value into m with the given key. // Set sets the value into m with the given key.
func (m *SafeMap) Set(key, value any) { func (m *SafeMap) Set(key, value interface{}) {
m.lock.Lock() m.lock.Lock()
defer m.lock.Unlock()
if m.deletionOld <= maxDeletion { if m.deletionOld <= maxDeletion {
if _, ok := m.dirtyNew[key]; ok { if _, ok := m.dirtyNew[key]; ok {
delete(m.dirtyNew, key) delete(m.dirtyNew, key)
@@ -105,6 +102,7 @@ func (m *SafeMap) Set(key, value any) {
} }
m.dirtyNew[key] = value m.dirtyNew[key] = value
} }
m.lock.Unlock()
} }
// Size returns the size of m. // Size returns the size of m.

View File

@@ -138,7 +138,7 @@ func TestSafeMap_Range(t *testing.T) {
} }
var count int32 var count int32
m.Range(func(k, v any) bool { m.Range(func(k, v interface{}) bool {
atomic.AddInt32(&count, 1) atomic.AddInt32(&count, 1)
newMap.Set(k, v) newMap.Set(k, v)
return true return true
@@ -147,65 +147,3 @@ func TestSafeMap_Range(t *testing.T) {
assert.Equal(t, m.dirtyNew, newMap.dirtyNew) assert.Equal(t, m.dirtyNew, newMap.dirtyNew)
assert.Equal(t, m.dirtyOld, newMap.dirtyOld) assert.Equal(t, m.dirtyOld, newMap.dirtyOld)
} }
func TestSetManyTimes(t *testing.T) {
const iteration = maxDeletion * 2
m := NewSafeMap()
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
var count int
m.Range(func(k, v any) bool {
count++
return count < maxDeletion/2
})
assert.Equal(t, maxDeletion/2, count)
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
count = 0
m.Range(func(k, v any) bool {
count++
return count < maxDeletion
})
assert.Equal(t, maxDeletion, count)
}
func TestSetManyTimesNew(t *testing.T) {
m := NewSafeMap()
for i := 0; i < maxDeletion*3; i++ {
m.Set(i, i)
}
for i := 0; i < maxDeletion*2; i++ {
m.Del(i)
}
for i := 0; i < maxDeletion*3; i++ {
m.Set(i+maxDeletion*3, i+maxDeletion*3)
}
for i := 0; i < maxDeletion*2; i++ {
m.Del(i + maxDeletion*2)
}
for i := 0; i < maxDeletion-copyThreshold+1; i++ {
m.Del(i + maxDeletion*2)
}
assert.Equal(t, 0, len(m.dirtyNew))
}

View File

@@ -1,53 +1,235 @@
package collection package collection
import "github.com/zeromicro/go-zero/core/lang" import (
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx"
)
// Set is a type-safe generic set collection. const (
// It's not thread-safe, use with synchronization for concurrent access. unmanaged = iota
type Set[T comparable] struct { untyped
data map[T]lang.PlaceholderType intType
int64Type
uintType
uint64Type
stringType
)
// Set is not thread-safe, for concurrent use, make sure to use it with synchronization.
type Set struct {
data map[interface{}]lang.PlaceholderType
tp int
} }
// NewSet returns a new type-safe set. // NewSet returns a managed Set, can only put the values with the same type.
func NewSet[T comparable]() *Set[T] { func NewSet() *Set {
return &Set[T]{ return &Set{
data: make(map[T]lang.PlaceholderType), data: make(map[interface{}]lang.PlaceholderType),
tp: untyped,
} }
} }
// Add adds items to the set. Duplicates are automatically ignored. // NewUnmanagedSet returns an unmanaged Set, which can put values with different types.
func (s *Set[T]) Add(items ...T) { func NewUnmanagedSet() *Set {
for _, item := range items { return &Set{
s.data[item] = lang.Placeholder data: make(map[interface{}]lang.PlaceholderType),
tp: unmanaged,
} }
} }
// Clear removes all items from the set. // Add adds i into s.
func (s *Set[T]) Clear() { func (s *Set) Add(i ...interface{}) {
clear(s.data) for _, each := range i {
s.add(each)
}
} }
// Contains checks if an item exists in the set. // AddInt adds int values ii into s.
func (s *Set[T]) Contains(item T) bool { func (s *Set) AddInt(ii ...int) {
_, ok := s.data[item] for _, each := range ii {
s.add(each)
}
}
// AddInt64 adds int64 values ii into s.
func (s *Set) AddInt64(ii ...int64) {
for _, each := range ii {
s.add(each)
}
}
// AddUint adds uint values ii into s.
func (s *Set) AddUint(ii ...uint) {
for _, each := range ii {
s.add(each)
}
}
// AddUint64 adds uint64 values ii into s.
func (s *Set) AddUint64(ii ...uint64) {
for _, each := range ii {
s.add(each)
}
}
// AddStr adds string values ss into s.
func (s *Set) AddStr(ss ...string) {
for _, each := range ss {
s.add(each)
}
}
// Contains checks if i is in s.
func (s *Set) Contains(i interface{}) bool {
if len(s.data) == 0 {
return false
}
s.validate(i)
_, ok := s.data[i]
return ok return ok
} }
// Count returns the number of items in the set. // Keys returns the keys in s.
func (s *Set[T]) Count() int { func (s *Set) Keys() []interface{} {
return len(s.data) var keys []interface{}
}
// Keys returns all elements in the set as a slice.
func (s *Set[T]) Keys() []T {
keys := make([]T, 0, len(s.data))
for key := range s.data { for key := range s.data {
keys = append(keys, key) keys = append(keys, key)
} }
return keys return keys
} }
// Remove removes an item from the set. // KeysInt returns the int keys in s.
func (s *Set[T]) Remove(item T) { func (s *Set) KeysInt() []int {
delete(s.data, item) var keys []int
for key := range s.data {
if intKey, ok := key.(int); ok {
keys = append(keys, intKey)
}
}
return keys
}
// KeysInt64 returns int64 keys in s.
func (s *Set) KeysInt64() []int64 {
var keys []int64
for key := range s.data {
if intKey, ok := key.(int64); ok {
keys = append(keys, intKey)
}
}
return keys
}
// KeysUint returns uint keys in s.
func (s *Set) KeysUint() []uint {
var keys []uint
for key := range s.data {
if intKey, ok := key.(uint); ok {
keys = append(keys, intKey)
}
}
return keys
}
// KeysUint64 returns uint64 keys in s.
func (s *Set) KeysUint64() []uint64 {
var keys []uint64
for key := range s.data {
if intKey, ok := key.(uint64); ok {
keys = append(keys, intKey)
}
}
return keys
}
// KeysStr returns string keys in s.
func (s *Set) KeysStr() []string {
var keys []string
for key := range s.data {
if strKey, ok := key.(string); ok {
keys = append(keys, strKey)
}
}
return keys
}
// Remove removes i from s.
func (s *Set) Remove(i interface{}) {
s.validate(i)
delete(s.data, i)
}
// Count returns the number of items in s.
func (s *Set) Count() int {
return len(s.data)
}
func (s *Set) add(i interface{}) {
switch s.tp {
case unmanaged:
// do nothing
case untyped:
s.setType(i)
default:
s.validate(i)
}
s.data[i] = lang.Placeholder
}
func (s *Set) setType(i interface{}) {
// s.tp can only be untyped here
switch i.(type) {
case int:
s.tp = intType
case int64:
s.tp = int64Type
case uint:
s.tp = uintType
case uint64:
s.tp = uint64Type
case string:
s.tp = stringType
}
}
func (s *Set) validate(i interface{}) {
if s.tp == unmanaged {
return
}
switch i.(type) {
case int:
if s.tp != intType {
logx.Errorf("element is int, but set contains elements with type %d", s.tp)
}
case int64:
if s.tp != int64Type {
logx.Errorf("element is int64, but set contains elements with type %d", s.tp)
}
case uint:
if s.tp != uintType {
logx.Errorf("element is uint, but set contains elements with type %d", s.tp)
}
case uint64:
if s.tp != uint64Type {
logx.Errorf("element is uint64, but set contains elements with type %d", s.tp)
}
case string:
if s.tp != stringType {
logx.Errorf("element is string, but set contains elements with type %d", s.tp)
}
}
} }

View File

@@ -12,117 +12,34 @@ func init() {
logx.Disable() logx.Disable()
} }
// Set functionality tests
func TestTypedSetInt(t *testing.T) {
set := NewSet[int]()
values := []int{1, 2, 3, 2, 1} // Contains duplicates
// Test adding
set.Add(values...)
assert.Equal(t, 3, set.Count()) // Should only have 3 elements after deduplication
// Test contains
assert.True(t, set.Contains(1))
assert.True(t, set.Contains(2))
assert.True(t, set.Contains(3))
assert.False(t, set.Contains(4))
// Test getting all keys
keys := set.Keys()
sort.Ints(keys)
assert.EqualValues(t, []int{1, 2, 3}, keys)
// Test removal
set.Remove(2)
assert.False(t, set.Contains(2))
assert.Equal(t, 2, set.Count())
}
func TestTypedSetStringOps(t *testing.T) {
set := NewSet[string]()
values := []string{"a", "b", "c", "b", "a"}
set.Add(values...)
assert.Equal(t, 3, set.Count())
assert.True(t, set.Contains("a"))
assert.True(t, set.Contains("b"))
assert.True(t, set.Contains("c"))
assert.False(t, set.Contains("d"))
keys := set.Keys()
sort.Strings(keys)
assert.EqualValues(t, []string{"a", "b", "c"}, keys)
}
func TestTypedSetClear(t *testing.T) {
set := NewSet[int]()
set.Add(1, 2, 3)
assert.Equal(t, 3, set.Count())
set.Clear()
assert.Equal(t, 0, set.Count())
assert.False(t, set.Contains(1))
}
func TestTypedSetEmpty(t *testing.T) {
set := NewSet[int]()
assert.Equal(t, 0, set.Count())
assert.False(t, set.Contains(1))
assert.Empty(t, set.Keys())
}
func TestTypedSetMultipleTypes(t *testing.T) {
// Test different typed generic sets
intSet := NewSet[int]()
int64Set := NewSet[int64]()
uintSet := NewSet[uint]()
uint64Set := NewSet[uint64]()
stringSet := NewSet[string]()
intSet.Add(1, 2, 3)
int64Set.Add(1, 2, 3)
uintSet.Add(1, 2, 3)
uint64Set.Add(1, 2, 3)
stringSet.Add("1", "2", "3")
assert.Equal(t, 3, intSet.Count())
assert.Equal(t, 3, int64Set.Count())
assert.Equal(t, 3, uintSet.Count())
assert.Equal(t, 3, uint64Set.Count())
assert.Equal(t, 3, stringSet.Count())
}
// Set benchmarks
func BenchmarkTypedIntSet(b *testing.B) {
s := NewSet[int]()
for i := 0; i < b.N; i++ {
s.Add(i)
_ = s.Contains(i)
}
}
func BenchmarkTypedStringSet(b *testing.B) {
s := NewSet[string]()
for i := 0; i < b.N; i++ {
s.Add(string(rune(i)))
_ = s.Contains(string(rune(i)))
}
}
// Legacy tests remain unchanged for backward compatibility
func BenchmarkRawSet(b *testing.B) { func BenchmarkRawSet(b *testing.B) {
m := make(map[any]struct{}) m := make(map[interface{}]struct{})
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
m[i] = struct{}{} m[i] = struct{}{}
_ = m[i] _ = m[i]
} }
} }
func BenchmarkUnmanagedSet(b *testing.B) {
s := NewUnmanagedSet()
for i := 0; i < b.N; i++ {
s.Add(i)
_ = s.Contains(i)
}
}
func BenchmarkSet(b *testing.B) {
s := NewSet()
for i := 0; i < b.N; i++ {
s.AddInt(i)
_ = s.Contains(i)
}
}
func TestAdd(t *testing.T) { func TestAdd(t *testing.T) {
// given // given
set := NewSet[int]() set := NewUnmanagedSet()
values := []int{1, 2, 3} values := []interface{}{1, 2, 3}
// when // when
set.Add(values...) set.Add(values...)
@@ -134,74 +51,82 @@ func TestAdd(t *testing.T) {
func TestAddInt(t *testing.T) { func TestAddInt(t *testing.T) {
// given // given
set := NewSet[int]() set := NewSet()
values := []int{1, 2, 3} values := []int{1, 2, 3}
// when // when
set.Add(values...) set.AddInt(values...)
// then // then
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3)) assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
keys := set.Keys() keys := set.KeysInt()
sort.Ints(keys) sort.Ints(keys)
assert.EqualValues(t, values, keys) assert.EqualValues(t, values, keys)
} }
func TestAddInt64(t *testing.T) { func TestAddInt64(t *testing.T) {
// given // given
set := NewSet[int64]() set := NewSet()
values := []int64{1, 2, 3} values := []int64{1, 2, 3}
// when // when
set.Add(values...) set.AddInt64(values...)
// then // then
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3)) assert.True(t, set.Contains(int64(1)) && set.Contains(int64(2)) && set.Contains(int64(3)))
assert.Equal(t, len(values), len(set.Keys())) assert.Equal(t, len(values), len(set.KeysInt64()))
} }
func TestAddUint(t *testing.T) { func TestAddUint(t *testing.T) {
// given // given
set := NewSet[uint]() set := NewSet()
values := []uint{1, 2, 3} values := []uint{1, 2, 3}
// when // when
set.Add(values...) set.AddUint(values...)
// then // then
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3)) assert.True(t, set.Contains(uint(1)) && set.Contains(uint(2)) && set.Contains(uint(3)))
assert.Equal(t, len(values), len(set.Keys())) assert.Equal(t, len(values), len(set.KeysUint()))
} }
func TestAddUint64(t *testing.T) { func TestAddUint64(t *testing.T) {
// given // given
set := NewSet[uint64]() set := NewSet()
values := []uint64{1, 2, 3} values := []uint64{1, 2, 3}
// when // when
set.Add(values...) set.AddUint64(values...)
// then // then
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3)) assert.True(t, set.Contains(uint64(1)) && set.Contains(uint64(2)) && set.Contains(uint64(3)))
assert.Equal(t, len(values), len(set.Keys())) assert.Equal(t, len(values), len(set.KeysUint64()))
} }
func TestAddStr(t *testing.T) { func TestAddStr(t *testing.T) {
// given // given
set := NewSet[string]() set := NewSet()
values := []string{"1", "2", "3"} values := []string{"1", "2", "3"}
// when // when
set.Add(values...) set.AddStr(values...)
// then // then
assert.True(t, set.Contains("1") && set.Contains("2") && set.Contains("3")) assert.True(t, set.Contains("1") && set.Contains("2") && set.Contains("3"))
assert.Equal(t, len(values), len(set.Keys())) assert.Equal(t, len(values), len(set.KeysStr()))
} }
func TestContainsWithoutElements(t *testing.T) { func TestContainsWithoutElements(t *testing.T) {
// given // given
set := NewSet[int]() set := NewSet()
// then
assert.False(t, set.Contains(1))
}
func TestContainsUnmanagedWithoutElements(t *testing.T) {
// given
set := NewUnmanagedSet()
// then // then
assert.False(t, set.Contains(1)) assert.False(t, set.Contains(1))
@@ -209,8 +134,8 @@ func TestContainsWithoutElements(t *testing.T) {
func TestRemove(t *testing.T) { func TestRemove(t *testing.T) {
// given // given
set := NewSet[int]() set := NewSet()
set.Add([]int{1, 2, 3}...) set.Add([]interface{}{1, 2, 3}...)
// when // when
set.Remove(2) set.Remove(2)
@@ -221,9 +146,57 @@ func TestRemove(t *testing.T) {
func TestCount(t *testing.T) { func TestCount(t *testing.T) {
// given // given
set := NewSet[int]() set := NewSet()
set.Add([]int{1, 2, 3}...) set.Add([]interface{}{1, 2, 3}...)
// then // then
assert.Equal(t, set.Count(), 3) assert.Equal(t, set.Count(), 3)
} }
func TestKeysIntMismatch(t *testing.T) {
set := NewSet()
set.add(int64(1))
set.add(2)
vals := set.KeysInt()
assert.EqualValues(t, []int{2}, vals)
}
func TestKeysInt64Mismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add(int64(2))
vals := set.KeysInt64()
assert.EqualValues(t, []int64{2}, vals)
}
func TestKeysUintMismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add(uint(2))
vals := set.KeysUint()
assert.EqualValues(t, []uint{2}, vals)
}
func TestKeysUint64Mismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add(uint64(2))
vals := set.KeysUint64()
assert.EqualValues(t, []uint64{2}, vals)
}
func TestKeysStrMismatch(t *testing.T) {
set := NewSet()
set.add(1)
set.add("2")
vals := set.KeysStr()
assert.EqualValues(t, []string{"2"}, vals)
}
func TestSetType(t *testing.T) {
set := NewUnmanagedSet()
set.add(1)
set.add("2")
vals := set.Keys()
assert.ElementsMatch(t, []interface{}{1, "2"}, vals)
}

View File

@@ -20,7 +20,7 @@ var (
type ( type (
// Execute defines the method to execute the task. // Execute defines the method to execute the task.
Execute func(key, value any) Execute func(key, value interface{})
// A TimingWheel is a timing wheel object to schedule tasks. // A TimingWheel is a timing wheel object to schedule tasks.
TimingWheel struct { TimingWheel struct {
@@ -33,14 +33,14 @@ type (
execute Execute execute Execute
setChannel chan timingEntry setChannel chan timingEntry
moveChannel chan baseEntry moveChannel chan baseEntry
removeChannel chan any removeChannel chan interface{}
drainChannel chan func(key, value any) drainChannel chan func(key, value interface{})
stopChannel chan lang.PlaceholderType stopChannel chan lang.PlaceholderType
} }
timingEntry struct { timingEntry struct {
baseEntry baseEntry
value any value interface{}
circle int circle int
diff int diff int
removed bool removed bool
@@ -48,7 +48,7 @@ type (
baseEntry struct { baseEntry struct {
delay time.Duration delay time.Duration
key any key interface{}
} }
positionEntry struct { positionEntry struct {
@@ -57,8 +57,8 @@ type (
} }
timingTask struct { timingTask struct {
key any key interface{}
value any value interface{}
} }
) )
@@ -85,8 +85,8 @@ func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Exec
numSlots: numSlots, numSlots: numSlots,
setChannel: make(chan timingEntry), setChannel: make(chan timingEntry),
moveChannel: make(chan baseEntry), moveChannel: make(chan baseEntry),
removeChannel: make(chan any), removeChannel: make(chan interface{}),
drainChannel: make(chan func(key, value any)), drainChannel: make(chan func(key, value interface{})),
stopChannel: make(chan lang.PlaceholderType), stopChannel: make(chan lang.PlaceholderType),
} }
@@ -97,7 +97,7 @@ func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Exec
} }
// Drain drains all items and executes them. // Drain drains all items and executes them.
func (tw *TimingWheel) Drain(fn func(key, value any)) error { func (tw *TimingWheel) Drain(fn func(key, value interface{})) error {
select { select {
case tw.drainChannel <- fn: case tw.drainChannel <- fn:
return nil return nil
@@ -107,7 +107,7 @@ func (tw *TimingWheel) Drain(fn func(key, value any)) error {
} }
// MoveTimer moves the task with the given key to the given delay. // MoveTimer moves the task with the given key to the given delay.
func (tw *TimingWheel) MoveTimer(key any, delay time.Duration) error { func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) error {
if delay <= 0 || key == nil { if delay <= 0 || key == nil {
return ErrArgument return ErrArgument
} }
@@ -124,7 +124,7 @@ func (tw *TimingWheel) MoveTimer(key any, delay time.Duration) error {
} }
// RemoveTimer removes the task with the given key. // RemoveTimer removes the task with the given key.
func (tw *TimingWheel) RemoveTimer(key any) error { func (tw *TimingWheel) RemoveTimer(key interface{}) error {
if key == nil { if key == nil {
return ErrArgument return ErrArgument
} }
@@ -138,7 +138,7 @@ func (tw *TimingWheel) RemoveTimer(key any) error {
} }
// SetTimer sets the task value with the given key to the delay. // SetTimer sets the task value with the given key to the delay.
func (tw *TimingWheel) SetTimer(key, value any, delay time.Duration) error { func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) error {
if delay <= 0 || key == nil { if delay <= 0 || key == nil {
return ErrArgument return ErrArgument
} }
@@ -162,7 +162,7 @@ func (tw *TimingWheel) Stop() {
close(tw.stopChannel) close(tw.stopChannel)
} }
func (tw *TimingWheel) drainAll(fn func(key, value any)) { func (tw *TimingWheel) drainAll(fn func(key, value interface{})) {
runner := threading.NewTaskRunner(drainWorkers) runner := threading.NewTaskRunner(drainWorkers)
for _, slot := range tw.slots { for _, slot := range tw.slots {
for e := slot.Front(); e != nil; { for e := slot.Front(); e != nil; {
@@ -232,7 +232,7 @@ func (tw *TimingWheel) onTick() {
tw.scanAndRunTasks(l) tw.scanAndRunTasks(l)
} }
func (tw *TimingWheel) removeTask(key any) { func (tw *TimingWheel) removeTask(key interface{}) {
val, ok := tw.timers.Get(key) val, ok := tw.timers.Get(key)
if !ok { if !ok {
return return

View File

@@ -20,13 +20,13 @@ const (
) )
func TestNewTimingWheel(t *testing.T) { func TestNewTimingWheel(t *testing.T) {
_, err := NewTimingWheel(0, 10, func(key, value any) {}) _, err := NewTimingWheel(0, 10, func(key, value interface{}) {})
assert.NotNil(t, err) assert.NotNil(t, err)
} }
func TestTimingWheel_Drain(t *testing.T) { func TestTimingWheel_Drain(t *testing.T) {
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) { tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
}, ticker) }, ticker)
tw.SetTimer("first", 3, testStep*4) tw.SetTimer("first", 3, testStep*4)
tw.SetTimer("second", 5, testStep*7) tw.SetTimer("second", 5, testStep*7)
@@ -36,7 +36,7 @@ func TestTimingWheel_Drain(t *testing.T) {
var lock sync.Mutex var lock sync.Mutex
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(3) wg.Add(3)
tw.Drain(func(key, value any) { tw.Drain(func(key, value interface{}) {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
keys = append(keys, key.(string)) keys = append(keys, key.(string))
@@ -50,19 +50,19 @@ func TestTimingWheel_Drain(t *testing.T) {
assert.EqualValues(t, []string{"first", "second", "third"}, keys) assert.EqualValues(t, []string{"first", "second", "third"}, keys)
assert.EqualValues(t, []int{3, 5, 7}, vals) assert.EqualValues(t, []int{3, 5, 7}, vals)
var count int var count int
tw.Drain(func(key, value any) { tw.Drain(func(key, value interface{}) {
count++ count++
}) })
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
assert.Equal(t, 0, count) assert.Equal(t, 0, count)
tw.Stop() tw.Stop()
assert.Equal(t, ErrClosed, tw.Drain(func(key, value any) {})) assert.Equal(t, ErrClosed, tw.Drain(func(key, value interface{}) {}))
} }
func TestTimingWheel_SetTimerSoon(t *testing.T) { func TestTimingWheel_SetTimerSoon(t *testing.T) {
run := syncx.NewAtomicBool() run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) { tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true)) assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k) assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int)) assert.Equal(t, 3, v.(int))
@@ -78,7 +78,7 @@ func TestTimingWheel_SetTimerSoon(t *testing.T) {
func TestTimingWheel_SetTimerTwice(t *testing.T) { func TestTimingWheel_SetTimerTwice(t *testing.T) {
run := syncx.NewAtomicBool() run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) { tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true)) assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k) assert.Equal(t, "any", k)
assert.Equal(t, 5, v.(int)) assert.Equal(t, 5, v.(int))
@@ -96,7 +96,7 @@ func TestTimingWheel_SetTimerTwice(t *testing.T) {
func TestTimingWheel_SetTimerWrongDelay(t *testing.T) { func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker) tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
defer tw.Stop() defer tw.Stop()
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
tw.SetTimer("any", 3, -testStep) tw.SetTimer("any", 3, -testStep)
@@ -105,7 +105,7 @@ func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
func TestTimingWheel_SetTimerAfterClose(t *testing.T) { func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker) tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
tw.Stop() tw.Stop()
assert.Equal(t, ErrClosed, tw.SetTimer("any", 3, testStep)) assert.Equal(t, ErrClosed, tw.SetTimer("any", 3, testStep))
} }
@@ -113,7 +113,7 @@ func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
func TestTimingWheel_MoveTimer(t *testing.T) { func TestTimingWheel_MoveTimer(t *testing.T) {
run := syncx.NewAtomicBool() run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v any) { tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true)) assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k) assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int)) assert.Equal(t, 3, v.(int))
@@ -139,7 +139,7 @@ func TestTimingWheel_MoveTimer(t *testing.T) {
func TestTimingWheel_MoveTimerSoon(t *testing.T) { func TestTimingWheel_MoveTimerSoon(t *testing.T) {
run := syncx.NewAtomicBool() run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v any) { tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true)) assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k) assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int)) assert.Equal(t, 3, v.(int))
@@ -155,7 +155,7 @@ func TestTimingWheel_MoveTimerSoon(t *testing.T) {
func TestTimingWheel_MoveTimerEarlier(t *testing.T) { func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
run := syncx.NewAtomicBool() run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) { tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true)) assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k) assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int)) assert.Equal(t, 3, v.(int))
@@ -173,7 +173,7 @@ func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
func TestTimingWheel_RemoveTimer(t *testing.T) { func TestTimingWheel_RemoveTimer(t *testing.T) {
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker) tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
tw.SetTimer("any", 3, testStep) tw.SetTimer("any", 3, testStep)
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
tw.RemoveTimer("any") tw.RemoveTimer("any")
@@ -236,7 +236,7 @@ func TestTimingWheel_SetTimer(t *testing.T) {
} }
var actual int32 var actual int32
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) { tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
assert.Equal(t, 1, key.(int)) assert.Equal(t, 1, key.(int))
assert.Equal(t, 2, value.(int)) assert.Equal(t, 2, value.(int))
actual = atomic.LoadInt32(&count) actual = atomic.LoadInt32(&count)
@@ -317,7 +317,7 @@ func TestTimingWheel_SetAndMoveThenStart(t *testing.T) {
} }
var actual int32 var actual int32
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) { tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count) actual = atomic.LoadInt32(&count)
close(done) close(done)
}, ticker) }, ticker)
@@ -405,7 +405,7 @@ func TestTimingWheel_SetAndMoveTwice(t *testing.T) {
} }
var actual int32 var actual int32
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) { tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count) actual = atomic.LoadInt32(&count)
close(done) close(done)
}, ticker) }, ticker)
@@ -486,7 +486,7 @@ func TestTimingWheel_ElapsedAndSet(t *testing.T) {
} }
var actual int32 var actual int32
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) { tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count) actual = atomic.LoadInt32(&count)
close(done) close(done)
}, ticker) }, ticker)
@@ -577,7 +577,7 @@ func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
} }
var actual int32 var actual int32
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value any) { tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count) actual = atomic.LoadInt32(&count)
close(done) close(done)
}, ticker) }, ticker)
@@ -612,7 +612,7 @@ func TestMoveAndRemoveTask(t *testing.T) {
} }
} }
var keys []int var keys []int
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) { tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
assert.Equal(t, "any", k) assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int)) assert.Equal(t, 3, v.(int))
keys = append(keys, v.(int)) keys = append(keys, v.(int))
@@ -632,7 +632,7 @@ func TestMoveAndRemoveTask(t *testing.T) {
func BenchmarkTimingWheel(b *testing.B) { func BenchmarkTimingWheel(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
tw, _ := NewTimingWheel(time.Second, 100, func(k, v any) {}) tw, _ := NewTimingWheel(time.Second, 100, func(k, v interface{}) {})
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tw.SetTimer(i, i, time.Second) tw.SetTimer(i, i, time.Second)
tw.SetTimer(b.N+i, b.N+i, time.Second) tw.SetTimer(b.N+i, b.N+i, time.Second)

View File

@@ -13,14 +13,11 @@ import (
"github.com/zeromicro/go-zero/internal/encoding" "github.com/zeromicro/go-zero/internal/encoding"
) )
const ( const jsonTagKey = "json"
jsonTagKey = "json"
jsonTagSep = ','
)
var ( var (
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault()) fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault())
loaders = map[string]func([]byte, any) error{ loaders = map[string]func([]byte, interface{}) error{
".json": LoadFromJsonBytes, ".json": LoadFromJsonBytes,
".toml": LoadFromTomlBytes, ".toml": LoadFromTomlBytes,
".yaml": LoadFromYamlBytes, ".yaml": LoadFromYamlBytes,
@@ -37,12 +34,12 @@ type fieldInfo struct {
// FillDefault fills the default values for the given v, // FillDefault fills the default values for the given v,
// and the premise is that the value of v must be guaranteed to be empty. // and the premise is that the value of v must be guaranteed to be empty.
func FillDefault(v any) error { func FillDefault(v interface{}) error {
return fillDefaultUnmarshaler.Unmarshal(map[string]any{}, v) return fillDefaultUnmarshaler.Unmarshal(map[string]interface{}{}, v)
} }
// Load loads config into v from file, .json, .yaml and .yml are acceptable. // Load loads config into v from file, .json, .yaml and .yml are acceptable.
func Load(file string, v any, opts ...Option) error { func Load(file string, v interface{}, opts ...Option) error {
content, err := os.ReadFile(file) content, err := os.ReadFile(file)
if err != nil { if err != nil {
return err return err
@@ -62,49 +59,40 @@ func Load(file string, v any, opts ...Option) error {
return loader([]byte(os.ExpandEnv(string(content))), v) return loader([]byte(os.ExpandEnv(string(content))), v)
} }
if err = loader(content, v); err != nil { return loader(content, v)
return err
}
return validate(v)
} }
// LoadConfig loads config into v from file, .json, .yaml and .yml are acceptable. // LoadConfig loads config into v from file, .json, .yaml and .yml are acceptable.
// Deprecated: use Load instead. // Deprecated: use Load instead.
func LoadConfig(file string, v any, opts ...Option) error { func LoadConfig(file string, v interface{}, opts ...Option) error {
return Load(file, v, opts...) return Load(file, v, opts...)
} }
// LoadFromJsonBytes loads config into v from content json bytes. // LoadFromJsonBytes loads config into v from content json bytes.
func LoadFromJsonBytes(content []byte, v any) error { func LoadFromJsonBytes(content []byte, v interface{}) error {
info, err := buildFieldsInfo(reflect.TypeOf(v), "") info, err := buildFieldsInfo(reflect.TypeOf(v))
if err != nil { if err != nil {
return err return err
} }
var m map[string]any var m map[string]interface{}
if err = jsonx.Unmarshal(content, &m); err != nil { if err := jsonx.Unmarshal(content, &m); err != nil {
return err return err
} }
lowerCaseKeyMap := toLowerCaseKeyMap(m, info) lowerCaseKeyMap := toLowerCaseKeyMap(m, info)
if err = mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
mapping.WithCanonicalKeyFunc(toLowerCase)); err != nil {
return err
}
return validate(v)
} }
// LoadConfigFromJsonBytes loads config into v from content json bytes. // LoadConfigFromJsonBytes loads config into v from content json bytes.
// Deprecated: use LoadFromJsonBytes instead. // Deprecated: use LoadFromJsonBytes instead.
func LoadConfigFromJsonBytes(content []byte, v any) error { func LoadConfigFromJsonBytes(content []byte, v interface{}) error {
return LoadFromJsonBytes(content, v) return LoadFromJsonBytes(content, v)
} }
// LoadFromTomlBytes loads config into v from content toml bytes. // LoadFromTomlBytes loads config into v from content toml bytes.
func LoadFromTomlBytes(content []byte, v any) error { func LoadFromTomlBytes(content []byte, v interface{}) error {
b, err := encoding.TomlToJson(content) b, err := encoding.TomlToJson(content)
if err != nil { if err != nil {
return err return err
@@ -114,7 +102,7 @@ func LoadFromTomlBytes(content []byte, v any) error {
} }
// LoadFromYamlBytes loads config into v from content yaml bytes. // LoadFromYamlBytes loads config into v from content yaml bytes.
func LoadFromYamlBytes(content []byte, v any) error { func LoadFromYamlBytes(content []byte, v interface{}) error {
b, err := encoding.YamlToJson(content) b, err := encoding.YamlToJson(content)
if err != nil { if err != nil {
return err return err
@@ -125,24 +113,24 @@ func LoadFromYamlBytes(content []byte, v any) error {
// LoadConfigFromYamlBytes loads config into v from content yaml bytes. // LoadConfigFromYamlBytes loads config into v from content yaml bytes.
// Deprecated: use LoadFromYamlBytes instead. // Deprecated: use LoadFromYamlBytes instead.
func LoadConfigFromYamlBytes(content []byte, v any) error { func LoadConfigFromYamlBytes(content []byte, v interface{}) error {
return LoadFromYamlBytes(content, v) return LoadFromYamlBytes(content, v)
} }
// MustLoad loads config into v from path, exits on error. // MustLoad loads config into v from path, exits on error.
func MustLoad(path string, v any, opts ...Option) { func MustLoad(path string, v interface{}, opts ...Option) {
if err := Load(path, v, opts...); err != nil { if err := Load(path, v, opts...); err != nil {
log.Fatalf("error: config file %s, %s", path, err.Error()) log.Fatalf("error: config file %s, %s", path, err.Error())
} }
} }
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo, fullName string) error { func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
if prev, ok := info.children[key]; ok { if prev, ok := info.children[key]; ok {
if child.mapField != nil { if child.mapField != nil {
return newConflictKeyError(fullName) return newDupKeyError(key)
} }
if err := mergeFields(prev, child.children, fullName); err != nil { if err := mergeFields(prev, key, child.children); err != nil {
return err return err
} }
} else { } else {
@@ -152,27 +140,27 @@ func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo, fullName st
return nil return nil
} }
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error { func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
switch ft.Kind() { switch ft.Kind() {
case reflect.Struct: case reflect.Struct:
fields, err := buildFieldsInfo(ft, fullName) fields, err := buildFieldsInfo(ft)
if err != nil { if err != nil {
return err return err
} }
for k, v := range fields.children { for k, v := range fields.children {
if err = addOrMergeFields(info, k, v, fullName); err != nil { if err = addOrMergeFields(info, k, v); err != nil {
return err return err
} }
} }
case reflect.Map: case reflect.Map:
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName) elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil { if err != nil {
return err return err
} }
if _, ok := info.children[lowerCaseName]; ok { if _, ok := info.children[lowerCaseName]; ok {
return newConflictKeyError(fullName) return newDupKeyError(lowerCaseName)
} }
info.children[lowerCaseName] = &fieldInfo{ info.children[lowerCaseName] = &fieldInfo{
@@ -181,7 +169,7 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
} }
default: default:
if _, ok := info.children[lowerCaseName]; ok { if _, ok := info.children[lowerCaseName]; ok {
return newConflictKeyError(fullName) return newDupKeyError(lowerCaseName)
} }
info.children[lowerCaseName] = &fieldInfo{ info.children[lowerCaseName] = &fieldInfo{
@@ -192,16 +180,16 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
return nil return nil
} }
func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) { func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
tp = mapping.Deref(tp) tp = mapping.Deref(tp)
switch tp.Kind() { switch tp.Kind() {
case reflect.Struct: case reflect.Struct:
return buildStructFieldsInfo(tp, fullName) return buildStructFieldsInfo(tp)
case reflect.Array, reflect.Slice, reflect.Map: case reflect.Array, reflect.Slice:
return buildFieldsInfo(mapping.Deref(tp.Elem()), fullName) return buildFieldsInfo(mapping.Deref(tp.Elem()))
case reflect.Chan, reflect.Func: case reflect.Chan, reflect.Func:
return nil, fmt.Errorf("unsupported type: %s, fullName: %s", tp.Kind(), fullName) return nil, fmt.Errorf("unsupported type: %s", tp.Kind())
default: default:
return &fieldInfo{ return &fieldInfo{
children: make(map[string]*fieldInfo), children: make(map[string]*fieldInfo),
@@ -209,23 +197,23 @@ func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
} }
} }
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error { func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
var finfo *fieldInfo var finfo *fieldInfo
var err error var err error
switch ft.Kind() { switch ft.Kind() {
case reflect.Struct: case reflect.Struct:
finfo, err = buildFieldsInfo(ft, fullName) finfo, err = buildFieldsInfo(ft)
if err != nil { if err != nil {
return err return err
} }
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
finfo, err = buildFieldsInfo(ft.Elem(), fullName) finfo, err = buildFieldsInfo(ft.Elem())
if err != nil { if err != nil {
return err return err
} }
case reflect.Map: case reflect.Map:
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName) elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil { if err != nil {
return err return err
} }
@@ -235,37 +223,31 @@ func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type,
mapField: elemInfo, mapField: elemInfo,
} }
default: default:
finfo, err = buildFieldsInfo(ft, fullName) finfo, err = buildFieldsInfo(ft)
if err != nil { if err != nil {
return err return err
} }
} }
return addOrMergeFields(info, lowerCaseName, finfo, fullName) return addOrMergeFields(info, lowerCaseName, finfo)
} }
func buildStructFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) { func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
info := &fieldInfo{ info := &fieldInfo{
children: make(map[string]*fieldInfo), children: make(map[string]*fieldInfo),
} }
for i := 0; i < tp.NumField(); i++ { for i := 0; i < tp.NumField(); i++ {
field := tp.Field(i) field := tp.Field(i)
if !field.IsExported() { name := field.Name
continue
}
name := getTagName(field)
lowerCaseName := toLowerCase(name) lowerCaseName := toLowerCase(name)
ft := mapping.Deref(field.Type) ft := mapping.Deref(field.Type)
// flatten anonymous fields // flatten anonymous fields
if field.Anonymous { if field.Anonymous {
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft, if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil {
getFullName(fullName, lowerCaseName)); err != nil {
return nil, err return nil, err
} }
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft, } else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil {
getFullName(fullName, lowerCaseName)); err != nil {
return nil, err return nil, err
} }
} }
@@ -273,32 +255,15 @@ func buildStructFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error)
return info, nil return info, nil
} }
// getTagName get the tag name of the given field, if no tag name, use file.Name. func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error {
// field.Name is returned on tags like `json:""` and `json:",optional"`.
func getTagName(field reflect.StructField) string {
if tag, ok := field.Tag.Lookup(jsonTagKey); ok {
if pos := strings.IndexByte(tag, jsonTagSep); pos >= 0 {
tag = tag[:pos]
}
tag = strings.TrimSpace(tag)
if len(tag) > 0 {
return tag
}
}
return field.Name
}
func mergeFields(prev *fieldInfo, children map[string]*fieldInfo, fullName string) error {
if len(prev.children) == 0 || len(children) == 0 { if len(prev.children) == 0 || len(children) == 0 {
return newConflictKeyError(fullName) return newDupKeyError(key)
} }
// merge fields // merge fields
for k, v := range children { for k, v := range children {
if _, ok := prev.children[k]; ok { if _, ok := prev.children[k]; ok {
return newConflictKeyError(fullName) return newDupKeyError(k)
} }
prev.children[k] = v prev.children[k] = v
@@ -311,12 +276,12 @@ func toLowerCase(s string) string {
return strings.ToLower(s) return strings.ToLower(s)
} }
func toLowerCaseInterface(v any, info *fieldInfo) any { func toLowerCaseInterface(v interface{}, info *fieldInfo) interface{} {
switch vv := v.(type) { switch vv := v.(type) {
case map[string]any: case map[string]interface{}:
return toLowerCaseKeyMap(vv, info) return toLowerCaseKeyMap(vv, info)
case []any: case []interface{}:
arr := make([]any, 0, len(vv)) var arr []interface{}
for _, vvv := range vv { for _, vvv := range vv {
arr = append(arr, toLowerCaseInterface(vvv, info)) arr = append(arr, toLowerCaseInterface(vvv, info))
} }
@@ -326,8 +291,8 @@ func toLowerCaseInterface(v any, info *fieldInfo) any {
} }
} }
func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any { func toLowerCaseKeyMap(m map[string]interface{}, info *fieldInfo) map[string]interface{} {
res := make(map[string]any) res := make(map[string]interface{})
for k, v := range m { for k, v := range m {
ti, ok := info.children[k] ti, ok := info.children[k]
@@ -341,8 +306,6 @@ func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any {
res[lk] = toLowerCaseInterface(v, ti) res[lk] = toLowerCaseInterface(v, ti)
} else if info.mapField != nil { } else if info.mapField != nil {
res[k] = toLowerCaseInterface(v, info.mapField) res[k] = toLowerCaseInterface(v, info.mapField)
} else if vv, ok := v.(map[string]any); ok {
res[k] = toLowerCaseKeyMap(vv, info)
} else { } else {
res[k] = v res[k] = v
} }
@@ -351,22 +314,14 @@ func toLowerCaseKeyMap(m map[string]any, info *fieldInfo) map[string]any {
return res return res
} }
type conflictKeyError struct { type dupKeyError struct {
key string key string
} }
func newConflictKeyError(key string) conflictKeyError { func newDupKeyError(key string) dupKeyError {
return conflictKeyError{key: key} return dupKeyError{key: key}
} }
func (e conflictKeyError) Error() string { func (e dupKeyError) Error() string {
return fmt.Sprintf("conflict key %s, pay attention to anonymous fields", e.key) return fmt.Sprintf("duplicated key %s", e.key)
}
func getFullName(parent, child string) string {
if len(parent) == 0 {
return child
}
return strings.Join([]string{parent, child}, ".")
} }

View File

@@ -1,9 +1,7 @@
package conf package conf
import ( import (
"errors"
"os" "os"
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -11,7 +9,7 @@ import (
"github.com/zeromicro/go-zero/core/hash" "github.com/zeromicro/go-zero/core/hash"
) )
var dupErr conflictKeyError var dupErr dupKeyError
func TestLoadConfig_notExists(t *testing.T) { func TestLoadConfig_notExists(t *testing.T) {
assert.NotNil(t, Load("not_a_file", nil)) assert.NotNil(t, Load("not_a_file", nil))
@@ -36,13 +34,14 @@ func TestConfigJson(t *testing.T) {
"c": "${FOO}", "c": "${FOO}",
"d": "abcd!@#$112" "d": "abcd!@#$112"
}` }`
t.Setenv("FOO", "2")
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test, func(t *testing.T) { t.Run(test, func(t *testing.T) {
tmpfile, err := createTempFile(t, test, text) os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct { var val struct {
A string `json:"a"` A string `json:"a"`
@@ -81,9 +80,11 @@ b = 1
c = "${FOO}" c = "${FOO}"
d = "abcd!@#$112" d = "abcd!@#$112"
` `
t.Setenv("FOO", "2") os.Setenv("FOO", "2")
tmpfile, err := createTempFile(t, ".toml", text) defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct { var val struct {
A string `json:"a"` A string `json:"a"`
@@ -104,8 +105,9 @@ b = 1
c = "FOO" c = "FOO"
d = "abcd" d = "abcd"
` `
tmpfile, err := createTempFile(t, ".toml", text) tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct { var val struct {
A string `json:"a"` A string `json:"a"`
@@ -121,23 +123,6 @@ d = "abcd"
} }
} }
func TestConfigWithLower(t *testing.T) {
text := `a = "foo"
b = 1
`
tmpfile, err := createTempFile(t, ".toml", text)
assert.Nil(t, err)
var val struct {
A string `json:"a"`
b int
}
if assert.NoError(t, Load(tmpfile, &val)) {
assert.Equal(t, "foo", val.A)
assert.Equal(t, 0, val.b)
}
}
func TestConfigJsonCanonical(t *testing.T) { func TestConfigJsonCanonical(t *testing.T) {
text := []byte(`{"a": "foo", "B": "bar"}`) text := []byte(`{"a": "foo", "B": "bar"}`)
@@ -203,9 +188,11 @@ b = 1
c = "${FOO}" c = "${FOO}"
d = "abcd!@#112" d = "abcd!@#112"
` `
t.Setenv("FOO", "2") os.Setenv("FOO", "2")
tmpfile, err := createTempFile(t, ".toml", text) defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct { var val struct {
A string `json:"a"` A string `json:"a"`
@@ -233,12 +220,14 @@ func TestConfigJsonEnv(t *testing.T) {
"c": "${FOO}", "c": "${FOO}",
"d": "abcd!@#$a12 3" "d": "abcd!@#$a12 3"
}` }`
t.Setenv("FOO", "2")
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test, func(t *testing.T) { t.Run(test, func(t *testing.T) {
tmpfile, err := createTempFile(t, test, text) os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct { var val struct {
A string `json:"a"` A string `json:"a"`
@@ -643,7 +632,7 @@ func Test_FieldOverwrite(t *testing.T) {
Name2 *string Name2 *string
} }
validate := func(val any) { validate := func(val interface{}) {
input := []byte(`{"Name": "hello", "Name2": "world"}`) input := []byte(`{"Name": "hello", "Name2": "world"}`)
assert.NoError(t, LoadFromJsonBytes(input, val)) assert.NoError(t, LoadFromJsonBytes(input, val))
} }
@@ -679,11 +668,11 @@ func Test_FieldOverwrite(t *testing.T) {
Name *string Name *string
} }
validate := func(val any) { validate := func(val interface{}) {
input := []byte(`{"Name": "hello"}`) input := []byte(`{"Name": "hello"}`)
err := LoadFromJsonBytes(input, val) err := LoadFromJsonBytes(input, val)
assert.ErrorAs(t, err, &dupErr) assert.ErrorAs(t, err, &dupErr)
assert.Equal(t, newConflictKeyError("name").Error(), err.Error()) assert.Equal(t, newDupKeyError("name").Error(), err.Error())
} }
validate(&St1{}) validate(&St1{})
@@ -722,11 +711,11 @@ func Test_FieldOverwrite(t *testing.T) {
Name *int Name *int
} }
validate := func(val any) { validate := func(val interface{}) {
input := []byte(`{"Name": "hello"}`) input := []byte(`{"Name": "hello"}`)
err := LoadFromJsonBytes(input, val) err := LoadFromJsonBytes(input, val)
assert.ErrorAs(t, err, &dupErr) assert.ErrorAs(t, err, &dupErr)
assert.Error(t, err) assert.Equal(t, newDupKeyError("name").Error(), err.Error())
} }
validate(&St0{}) validate(&St0{})
@@ -1033,22 +1022,22 @@ func TestLoadNamedFieldOverwritten(t *testing.T) {
}) })
} }
func TestLoadLowerMemberShouldNotConflict(t *testing.T) { func createTempFile(ext, text string) (string, error) {
type ( tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
Redis struct { if err != nil {
db uint return "", err
} }
Config struct { if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
db uint return "", err
Redis }
}
)
var c Config filename := tmpfile.Name()
assert.NoError(t, LoadFromJsonBytes([]byte(`{}`), &c)) if err = tmpfile.Close(); err != nil {
assert.Zero(t, c.db) return "", err
assert.Zero(t, c.Redis.db) }
return filename, nil
} }
func TestFillDefaultUnmarshal(t *testing.T) { func TestFillDefaultUnmarshal(t *testing.T) {
@@ -1090,7 +1079,7 @@ func TestFillDefaultUnmarshal(t *testing.T) {
assert.Equal(t, st.C, "c") assert.Equal(t, st.C, "c")
}) })
t.Run("has value", func(t *testing.T) { t.Run("has vaue", func(t *testing.T) {
type St struct { type St struct {
A string `json:",default=a"` A string `json:",default=a"`
B string B string
@@ -1102,278 +1091,3 @@ func TestFillDefaultUnmarshal(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
}) })
} }
func TestConfigWithJsonTag(t *testing.T) {
t.Run("map with value", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
ValueMap map[string]Value `json:"Value"`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.ValueMap, 2)
}
})
t.Run("map with ptr value", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
ValueMap map[string]*Value `json:"Value"`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.ValueMap, 2)
}
})
t.Run("map with optional", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
Value map[string]Value `json:",optional"`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.Value, 2)
}
})
t.Run("map with empty tag", func(t *testing.T) {
var input = []byte(`[Value]
[Value.first]
Email = "foo"
[Value.second]
Email = "bar"`)
type Value struct {
Email string
}
type Config struct {
Value map[string]Value `json:" "`
}
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.Value, 2)
}
})
t.Run("multi layer map", func(t *testing.T) {
type Value struct {
User struct {
Name string
}
}
type Config struct {
Value map[string]map[string]Value
}
var input = []byte(`
[Value.first.User1.User]
Name = "foo"
[Value.second.User2.User]
Name = "bar"
`)
var c Config
if assert.NoError(t, LoadFromTomlBytes(input, &c)) {
assert.Len(t, c.Value, 2)
}
})
}
func Test_LoadBadConfig(t *testing.T) {
type Config struct {
Name string `json:"name,options=foo|bar"`
}
file, err := createTempFile(t, ".json", `{"name": "baz"}`)
assert.NoError(t, err)
var c Config
err = Load(file, &c)
assert.Error(t, err)
}
func Test_getFullName(t *testing.T) {
assert.Equal(t, "a.b", getFullName("a", "b"))
assert.Equal(t, "a", getFullName("", "a"))
}
func TestValidate(t *testing.T) {
t.Run("normal config", func(t *testing.T) {
var c mockConfig
err := LoadFromJsonBytes([]byte(`{"val": "hello", "number": 8}`), &c)
assert.NoError(t, err)
})
t.Run("error no int", func(t *testing.T) {
var c mockConfig
err := LoadFromJsonBytes([]byte(`{"val": "hello"}`), &c)
assert.Error(t, err)
})
t.Run("error no string", func(t *testing.T) {
var c mockConfig
err := LoadFromJsonBytes([]byte(`{"number": 8}`), &c)
assert.Error(t, err)
})
}
func Test_buildFieldsInfo(t *testing.T) {
type ParentSt struct {
Name string
M map[string]int
}
tests := []struct {
name string
t reflect.Type
ok bool
containsKey string
}{
{
name: "normal",
t: reflect.TypeOf(struct{ A string }{}),
ok: true,
},
{
name: "struct anonymous",
t: reflect.TypeOf(struct {
ParentSt
Name string
}{}),
ok: false,
containsKey: newConflictKeyError("name").Error(),
},
{
name: "struct ptr anonymous",
t: reflect.TypeOf(struct {
*ParentSt
Name string
}{}),
ok: false,
containsKey: newConflictKeyError("name").Error(),
},
{
name: "more struct anonymous",
t: reflect.TypeOf(struct {
Value struct {
ParentSt
Name string
}
}{}),
ok: false,
containsKey: newConflictKeyError("value.name").Error(),
},
{
name: "map anonymous",
t: reflect.TypeOf(struct {
ParentSt
M string
}{}),
ok: false,
containsKey: newConflictKeyError("m").Error(),
},
{
name: "map more anonymous",
t: reflect.TypeOf(struct {
Value struct {
ParentSt
M string
}
}{}),
ok: false,
containsKey: newConflictKeyError("value.m").Error(),
},
{
name: "struct slice anonymous",
t: reflect.TypeOf([]struct {
ParentSt
Name string
}{}),
ok: false,
containsKey: newConflictKeyError("name").Error(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := buildFieldsInfo(tt.t, "")
if tt.ok {
assert.NoError(t, err)
} else {
assert.Error(t, err)
assert.Equal(t, err.Error(), tt.containsKey)
}
})
}
}
func createTempFile(t *testing.T, ext, text string) (string, error) {
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
if err != nil {
return "", err
}
if err = os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
return "", err
}
filename := tmpFile.Name()
if err = tmpFile.Close(); err != nil {
return "", err
}
t.Cleanup(func() {
_ = os.Remove(filename)
})
return filename, nil
}
type mockConfig struct {
Val string
Number int
}
func (m mockConfig) Validate() error {
if len(m.Val) == 0 {
return errors.New("val is empty")
}
if m.Number == 0 {
return errors.New("number is zero")
}
return nil
}

View File

@@ -45,7 +45,8 @@ func TestPropertiesEnv(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(tmpfile) defer os.Remove(tmpfile)
t.Setenv("FOO", "2") os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
props, err := LoadProperties(tmpfile, UseEnv()) props, err := LoadProperties(tmpfile, UseEnv())
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -12,7 +12,7 @@ type RestfulConf struct {
MaxConns int `json:",default=10000"` MaxConns int `json:",default=10000"`
MaxBytes int64 `json:",default=1048576"` MaxBytes int64 `json:",default=1048576"`
Timeout time.Duration `json:",default=3s"` Timeout time.Duration `json:",default=3s"`
CpuThreshold int64 `json:",default=900,range=[0:1000)"` CpuThreshold int64 `json:",default=900,range=[0:1000]"`
} }
``` ```

View File

@@ -1,12 +0,0 @@
package conf
import "github.com/zeromicro/go-zero/core/validation"
// validate validates the value if it implements the Validator interface.
func validate(v any) error {
if val, ok := v.(validation.Validator); ok {
return val.Validate()
}
return nil
}

View File

@@ -1,81 +0,0 @@
package conf
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
type mockType int
func (m mockType) Validate() error {
if m < 10 {
return errors.New("invalid value")
}
return nil
}
type anotherMockType int
func Test_validate(t *testing.T) {
tests := []struct {
name string
v any
wantErr bool
}{
{
name: "invalid",
v: mockType(5),
wantErr: true,
},
{
name: "valid",
v: mockType(10),
wantErr: false,
},
{
name: "not validator",
v: anotherMockType(5),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validate(tt.v)
assert.Equal(t, tt.wantErr, err != nil)
})
}
}
type mockVal struct {
}
func (m mockVal) Validate() error {
return errors.New("invalid value")
}
func Test_validateValPtr(t *testing.T) {
tests := []struct {
name string
v any
wantErr bool
}{
{
name: "invalid",
v: mockVal{},
},
{
name: "invalid value",
v: &mockVal{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Error(t, validate(tt.v))
})
}
}

View File

@@ -1,200 +0,0 @@
package configurator
import (
"errors"
"fmt"
"reflect"
"strings"
"sync"
"sync/atomic"
"github.com/zeromicro/go-zero/core/configcenter/subscriber"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/core/threading"
)
var (
errEmptyConfig = errors.New("empty config value")
errMissingUnmarshalerType = errors.New("missing unmarshaler type")
)
// Configurator is the interface for configuration center.
type Configurator[T any] interface {
// GetConfig returns the subscription value.
GetConfig() (T, error)
// AddListener adds a listener to the subscriber.
AddListener(listener func())
}
type (
// Config is the configuration for Configurator.
Config struct {
// Type is the value type, yaml, json or toml.
Type string `json:",default=yaml,options=[yaml,json,toml]"`
// Log is the flag to control logging.
Log bool `json:",default=true"`
}
configCenter[T any] struct {
conf Config
unmarshaler LoaderFn
subscriber subscriber.Subscriber
listeners []func()
lock sync.Mutex
snapshot atomic.Value
}
value[T any] struct {
data string
marshalData T
err error
}
)
// Configurator is the interface for configuration center.
var _ Configurator[any] = (*configCenter[any])(nil)
// MustNewConfigCenter returns a Configurator, exits on errors.
func MustNewConfigCenter[T any](c Config, subscriber subscriber.Subscriber) Configurator[T] {
cc, err := NewConfigCenter[T](c, subscriber)
logx.Must(err)
return cc
}
// NewConfigCenter returns a Configurator.
func NewConfigCenter[T any](c Config, subscriber subscriber.Subscriber) (Configurator[T], error) {
unmarshaler, ok := Unmarshaler(strings.ToLower(c.Type))
if !ok {
return nil, fmt.Errorf("unknown format: %s", c.Type)
}
cc := &configCenter[T]{
conf: c,
unmarshaler: unmarshaler,
subscriber: subscriber,
}
if err := cc.loadConfig(); err != nil {
return nil, err
}
if err := cc.subscriber.AddListener(cc.onChange); err != nil {
return nil, err
}
if _, err := cc.GetConfig(); err != nil {
return nil, err
}
return cc, nil
}
// AddListener adds listener to s.
func (c *configCenter[T]) AddListener(listener func()) {
c.lock.Lock()
defer c.lock.Unlock()
c.listeners = append(c.listeners, listener)
}
// GetConfig return structured config.
func (c *configCenter[T]) GetConfig() (T, error) {
v := c.value()
if v == nil || len(v.data) == 0 {
var empty T
return empty, errEmptyConfig
}
return v.marshalData, v.err
}
// Value returns the subscription value.
func (c *configCenter[T]) Value() string {
v := c.value()
if v == nil {
return ""
}
return v.data
}
func (c *configCenter[T]) loadConfig() error {
v, err := c.subscriber.Value()
if err != nil {
if c.conf.Log {
logx.Errorf("ConfigCenter loads changed configuration, error: %v", err)
}
return err
}
if c.conf.Log {
logx.Infof("ConfigCenter loads changed configuration, content [%s]", v)
}
c.snapshot.Store(c.genValue(v))
return nil
}
func (c *configCenter[T]) onChange() {
if err := c.loadConfig(); err != nil {
return
}
c.lock.Lock()
listeners := make([]func(), len(c.listeners))
copy(listeners, c.listeners)
c.lock.Unlock()
for _, l := range listeners {
threading.GoSafe(l)
}
}
func (c *configCenter[T]) value() *value[T] {
content := c.snapshot.Load()
if content == nil {
return nil
}
return content.(*value[T])
}
func (c *configCenter[T]) genValue(data string) *value[T] {
v := &value[T]{
data: data,
}
if len(data) == 0 {
return v
}
t := reflect.TypeOf(v.marshalData)
// if the type is nil, it means that the user has not set the type of the configuration.
if t == nil {
v.err = errMissingUnmarshalerType
return v
}
t = mapping.Deref(t)
switch t.Kind() {
case reflect.Struct, reflect.Array, reflect.Slice:
if err := c.unmarshaler([]byte(data), &v.marshalData); err != nil {
v.err = err
if c.conf.Log {
logx.Errorf("ConfigCenter unmarshal configuration failed, err: %+v, content [%s]",
err.Error(), data)
}
}
case reflect.String:
if str, ok := any(data).(T); ok {
v.marshalData = str
} else {
v.err = errMissingUnmarshalerType
}
default:
if c.conf.Log {
logx.Errorf("ConfigCenter unmarshal configuration missing unmarshaler for type: %s, content [%s]",
t.Kind(), data)
}
v.err = errMissingUnmarshalerType
}
return v
}

View File

@@ -1,233 +0,0 @@
package configurator
import (
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNewConfigCenter(t *testing.T) {
_, err := NewConfigCenter[any](Config{
Log: true,
}, &mockSubscriber{})
assert.Error(t, err)
_, err = NewConfigCenter[any](Config{
Type: "json",
Log: true,
}, &mockSubscriber{})
assert.Error(t, err)
}
func TestConfigCenter_GetConfig(t *testing.T) {
mock := &mockSubscriber{}
type Data struct {
Name string `json:"name"`
}
mock.v = `{"name": "go-zero"}`
c1, err := NewConfigCenter[Data](Config{
Type: "json",
Log: true,
}, mock)
assert.NoError(t, err)
data, err := c1.GetConfig()
assert.NoError(t, err)
assert.Equal(t, "go-zero", data.Name)
mock.v = `{"name": "111"}`
c2, err := NewConfigCenter[Data](Config{Type: "json"}, mock)
assert.NoError(t, err)
mock.v = `{}`
c3, err := NewConfigCenter[string](Config{
Type: "json",
Log: true,
}, mock)
assert.NoError(t, err)
_, err = c3.GetConfig()
assert.NoError(t, err)
data, err = c2.GetConfig()
assert.NoError(t, err)
mock.lisErr = errors.New("mock error")
_, err = NewConfigCenter[Data](Config{
Type: "json",
Log: true,
}, mock)
assert.Error(t, err)
}
func TestConfigCenter_onChange(t *testing.T) {
mock := &mockSubscriber{}
type Data struct {
Name string `json:"name"`
}
mock.v = `{"name": "go-zero"}`
c1, err := NewConfigCenter[Data](Config{Type: "json", Log: true}, mock)
assert.NoError(t, err)
data, err := c1.GetConfig()
assert.NoError(t, err)
assert.Equal(t, "go-zero", data.Name)
mock.v = `{"name": "go-zero2"}`
mock.change()
data, err = c1.GetConfig()
assert.NoError(t, err)
assert.Equal(t, "go-zero2", data.Name)
mock.valErr = errors.New("mock error")
_, err = NewConfigCenter[Data](Config{Type: "json", Log: false}, mock)
assert.Error(t, err)
}
func TestConfigCenter_Value(t *testing.T) {
mock := &mockSubscriber{}
mock.v = "1234"
c, err := NewConfigCenter[string](Config{
Type: "json",
Log: true,
}, mock)
assert.NoError(t, err)
cc := c.(*configCenter[string])
assert.Equal(t, cc.Value(), "1234")
mock.valErr = errors.New("mock error")
_, err = NewConfigCenter[any](Config{
Type: "json",
Log: true,
}, mock)
assert.Error(t, err)
}
func TestConfigCenter_AddListener(t *testing.T) {
mock := &mockSubscriber{}
mock.v = "1234"
c, err := NewConfigCenter[string](Config{
Type: "json",
Log: true,
}, mock)
assert.NoError(t, err)
cc := c.(*configCenter[string])
var a, b int
var mutex sync.Mutex
cc.AddListener(func() {
mutex.Lock()
a = 1
mutex.Unlock()
})
cc.AddListener(func() {
mutex.Lock()
b = 2
mutex.Unlock()
})
assert.Equal(t, 2, len(cc.listeners))
mock.change()
time.Sleep(time.Millisecond * 100)
mutex.Lock()
assert.Equal(t, 1, a)
assert.Equal(t, 2, b)
mutex.Unlock()
}
func TestConfigCenter_genValue(t *testing.T) {
t.Run("data is empty", func(t *testing.T) {
c := &configCenter[string]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue("")
assert.Equal(t, "", v.data)
})
t.Run("invalid template type", func(t *testing.T) {
c := &configCenter[any]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue("xxxx")
assert.Equal(t, errMissingUnmarshalerType, v.err)
})
t.Run("unsupported template type", func(t *testing.T) {
c := &configCenter[int]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue("1")
assert.Equal(t, errMissingUnmarshalerType, v.err)
})
t.Run("supported template string type", func(t *testing.T) {
c := &configCenter[string]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue("12345")
assert.NoError(t, v.err)
assert.Equal(t, "12345", v.data)
})
t.Run("unmarshal fail", func(t *testing.T) {
c := &configCenter[struct {
Name string `json:"name"`
}]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue(`{"name":"new name}`)
assert.Equal(t, `{"name":"new name}`, v.data)
assert.Error(t, v.err)
})
t.Run("success", func(t *testing.T) {
c := &configCenter[struct {
Name string `json:"name"`
}]{
unmarshaler: registry.unmarshalers["json"],
conf: Config{Log: true},
}
v := c.genValue(`{"name":"new name"}`)
assert.Equal(t, `{"name":"new name"}`, v.data)
assert.Equal(t, "new name", v.marshalData.Name)
assert.NoError(t, v.err)
})
}
type mockSubscriber struct {
v string
lisErr, valErr error
listener func()
}
func (m *mockSubscriber) AddListener(listener func()) error {
m.listener = listener
return m.lisErr
}
func (m *mockSubscriber) Value() (string, error) {
return m.v, m.valErr
}
func (m *mockSubscriber) change() {
if m.listener != nil {
m.listener()
}
}

View File

@@ -1,67 +0,0 @@
package subscriber
import (
"github.com/zeromicro/go-zero/core/discov"
"github.com/zeromicro/go-zero/core/logx"
)
type (
// etcdSubscriber is a subscriber that subscribes to etcd.
etcdSubscriber struct {
*discov.Subscriber
}
// EtcdConf is the configuration for etcd.
EtcdConf = discov.EtcdConf
)
// MustNewEtcdSubscriber returns an etcd Subscriber, exits on errors.
func MustNewEtcdSubscriber(conf EtcdConf) Subscriber {
s, err := NewEtcdSubscriber(conf)
logx.Must(err)
return s
}
// NewEtcdSubscriber returns an etcd Subscriber.
func NewEtcdSubscriber(conf EtcdConf) (Subscriber, error) {
opts := buildSubOptions(conf)
s, err := discov.NewSubscriber(conf.Hosts, conf.Key, opts...)
if err != nil {
return nil, err
}
return &etcdSubscriber{Subscriber: s}, nil
}
// buildSubOptions constructs the options for creating a new etcd subscriber.
func buildSubOptions(conf EtcdConf) []discov.SubOption {
opts := []discov.SubOption{
discov.WithExactMatch(),
}
if len(conf.User) > 0 {
opts = append(opts, discov.WithSubEtcdAccount(conf.User, conf.Pass))
}
if len(conf.CertFile) > 0 || len(conf.CertKeyFile) > 0 || len(conf.CACertFile) > 0 {
opts = append(opts, discov.WithSubEtcdTLS(conf.CertFile, conf.CertKeyFile,
conf.CACertFile, conf.InsecureSkipVerify))
}
return opts
}
// AddListener adds a listener to the subscriber.
func (s *etcdSubscriber) AddListener(listener func()) error {
s.Subscriber.AddListener(listener)
return nil
}
// Value returns the value of the subscriber.
func (s *etcdSubscriber) Value() (string, error) {
vs := s.Subscriber.Values()
if len(vs) > 0 {
return vs[len(vs)-1], nil
}
return "", nil
}

View File

@@ -1,9 +0,0 @@
package subscriber
// Subscriber is the interface for configcenter subscribers.
type Subscriber interface {
// AddListener adds a listener to the subscriber.
AddListener(listener func()) error
// Value returns the value of the subscriber.
Value() (string, error)
}

View File

@@ -1,41 +0,0 @@
package configurator
import (
"sync"
"github.com/zeromicro/go-zero/core/conf"
)
var registry = &unmarshalerRegistry{
unmarshalers: map[string]LoaderFn{
"json": conf.LoadFromJsonBytes,
"toml": conf.LoadFromTomlBytes,
"yaml": conf.LoadFromYamlBytes,
},
}
type (
// LoaderFn is the function type for loading configuration.
LoaderFn func([]byte, any) error
// unmarshalerRegistry is the registry for unmarshalers.
unmarshalerRegistry struct {
unmarshalers map[string]LoaderFn
mu sync.RWMutex
}
)
// RegisterUnmarshaler registers an unmarshaler.
func RegisterUnmarshaler(name string, fn LoaderFn) {
registry.mu.Lock()
defer registry.mu.Unlock()
registry.unmarshalers[name] = fn
}
// Unmarshaler returns the unmarshaler by name.
func Unmarshaler(name string) (LoaderFn, bool) {
registry.mu.RLock()
defer registry.mu.RUnlock()
fn, ok := registry.unmarshalers[name]
return fn, ok
}

View File

@@ -1,28 +0,0 @@
package configurator
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestRegisterUnmarshaler(t *testing.T) {
RegisterUnmarshaler("test", func(data []byte, v interface{}) error {
return nil
})
_, ok := Unmarshaler("test")
assert.True(t, ok)
_, ok = Unmarshaler("test2")
assert.False(t, ok)
_, ok = Unmarshaler("json")
assert.True(t, ok)
_, ok = Unmarshaler("toml")
assert.True(t, ok)
_, ok = Unmarshaler("yaml")
assert.True(t, ok)
}

View File

@@ -14,13 +14,13 @@ type contextValuer struct {
context.Context context.Context
} }
func (cv contextValuer) Value(key string) (any, bool) { func (cv contextValuer) Value(key string) (interface{}, bool) {
v := cv.Context.Value(key) v := cv.Context.Value(key)
return v, v != nil return v, v != nil
} }
// For unmarshals ctx into v. // For unmarshals ctx into v.
func For(ctx context.Context, v any) error { func For(ctx context.Context, v interface{}) error {
return unmarshaler.UnmarshalValuer(contextValuer{ return unmarshaler.UnmarshalValuer(contextValuer{
Context: ctx, Context: ctx,
}, v) }, v)

View File

@@ -13,7 +13,6 @@ var (
type EtcdConf struct { type EtcdConf struct {
Hosts []string Hosts []string
Key string Key string
ID int64 `json:",optional"`
User string `json:",optional"` User string `json:",optional"`
Pass string `json:",optional"` Pass string `json:",optional"`
CertFile string `json:",optional"` CertFile string `json:",optional"`
@@ -27,11 +26,6 @@ func (c EtcdConf) HasAccount() bool {
return len(c.User) > 0 && len(c.Pass) > 0 return len(c.User) > 0 && len(c.Pass) > 0
} }
// HasID returns if ID provided.
func (c EtcdConf) HasID() bool {
return c.ID > 0
}
// HasTLS returns if TLS CertFile/CertKeyFile/CACertFile are provided. // HasTLS returns if TLS CertFile/CertKeyFile/CACertFile are provided.
func (c EtcdConf) HasTLS() bool { func (c EtcdConf) HasTLS() bool {
return len(c.CertFile) > 0 && len(c.CertKeyFile) > 0 && len(c.CACertFile) > 0 return len(c.CertFile) > 0 && len(c.CertKeyFile) > 0 && len(c.CACertFile) > 0

View File

@@ -80,90 +80,3 @@ func TestEtcdConf_HasAccount(t *testing.T) {
assert.Equal(t, test.hasAccount, test.EtcdConf.HasAccount()) assert.Equal(t, test.hasAccount, test.EtcdConf.HasAccount())
} }
} }
func TestEtcdConf_HasID(t *testing.T) {
tests := []struct {
EtcdConf
hasServerID bool
}{
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
ID: -1,
},
hasServerID: false,
},
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
ID: 0,
},
hasServerID: false,
},
{
EtcdConf: EtcdConf{
Hosts: []string{"any"},
ID: 10000,
},
hasServerID: true,
},
}
for _, test := range tests {
assert.Equal(t, test.hasServerID, test.EtcdConf.HasID())
}
}
func TestEtcdConf_HasTLS(t *testing.T) {
tests := []struct {
name string
conf EtcdConf
want bool
}{
{
name: "empty config",
conf: EtcdConf{},
want: false,
},
{
name: "missing CertFile",
conf: EtcdConf{
CertKeyFile: "key",
CACertFile: "ca",
},
want: false,
},
{
name: "missing CertKeyFile",
conf: EtcdConf{
CertFile: "cert",
CACertFile: "ca",
},
want: false,
},
{
name: "missing CACertFile",
conf: EtcdConf{
CertFile: "cert",
CertKeyFile: "key",
},
want: false,
},
{
name: "valid config",
conf: EtcdConf{
CertFile: "cert",
CertKeyFile: "key",
CACertFile: "ca",
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.conf.HasTLS()
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -1,85 +1,12 @@
package internal package internal
import ( import (
"os"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/core/stringx"
) )
const (
certContent = `-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUEg9GVO2oaPn+YSmiqmFIuAo10WIwDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMjNaGA8yMTIz
MDIxNTEzMjEyM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBALplXlWsIf0O/IgnIplmiZHKGnxyfyufyE2FBRNk
OofRqbKuPH8GNqbkvZm7N29fwTDAQ+mViAggCkDht4hOzoWJMA7KYJt8JnTSWL48
M1lcrpc9DL2gszC/JF/FGvyANbBtLklkZPFBGdHUX14pjrT937wqPtm+SqUHSvRT
B7bmwmm2drRcmhpVm98LSlV7uQ2EgnJgsLjBPITKUejLmVLHfgX0RwQ2xIpX9pS4
FCe1BTacwl2gGp7Mje7y4Mfv3o0ArJW6Tuwbjx59ZXwb1KIP71b7bT04AVS8ZeYO
UMLKKuB5UR9x9Rn6cLXOTWBpcMVyzDgrAFLZjnE9LPUolZMCAwEAAaNRME8wHwYD
VR0jBBgwFoAUeW8w8pmhncbRgTsl48k4/7wnfx8wCQYDVR0TBAIwADALBgNVHQ8E
BAMCBPAwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBDAUAA4IBAQAI
y9xaoS88CLPBsX6mxfcTAFVfGNTRW9VN9Ng1cCnUR+YGoXGM/l+qP4f7p8ocdGwK
iYZErVTzXYIn+D27//wpY3klJk3gAnEUBT3QRkStBw7XnpbeZ2oPBK+cmDnCnZPS
BIF1wxPX7vIgaxs5Zsdqwk3qvZ4Djr2wP7LabNWTLSBKgQoUY45Liw6pffLwcGF9
UKlu54bvGze2SufISCR3ib+I+FLvqpvJhXToZWYb/pfI/HccuCL1oot1x8vx6DQy
U+TYxlZsKS5mdNxAX3dqEkEMsgEi+g/tzDPXJImfeCGGBhIOXLm8SRypiuGdEbc9
xkWYxRPegajuEZGvCqVs
-----END CERTIFICATE-----`
keyContent = `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAumVeVawh/Q78iCcimWaJkcoafHJ/K5/ITYUFE2Q6h9Gpsq48
fwY2puS9mbs3b1/BMMBD6ZWICCAKQOG3iE7OhYkwDspgm3wmdNJYvjwzWVyulz0M
vaCzML8kX8Ua/IA1sG0uSWRk8UEZ0dRfXimOtP3fvCo+2b5KpQdK9FMHtubCabZ2
tFyaGlWb3wtKVXu5DYSCcmCwuME8hMpR6MuZUsd+BfRHBDbEilf2lLgUJ7UFNpzC
XaAansyN7vLgx+/ejQCslbpO7BuPHn1lfBvUog/vVvttPTgBVLxl5g5Qwsoq4HlR
H3H1Gfpwtc5NYGlwxXLMOCsAUtmOcT0s9SiVkwIDAQABAoIBAD5meTJNMgO55Kjg
ESExxpRcCIno+tHr5+6rvYtEXqPheOIsmmwb9Gfi4+Z3WpOaht5/Pz0Ppj6yGzyl
U//6AgGKb+BDuBvVcDpjwPnOxZIBCSHwejdxeQu0scSuA97MPS0XIAvJ5FEv7ijk
5Bht6SyGYURpECltHygoTNuGgGqmO+McCJRLE9L09lTBI6UQ/JQwWJqSr7wx6iPU
M1Ze/srIV+7cyEPu6i0DGjS1gSQKkX68Lqn1w6oE290O+OZvleO0gZ02fLDWCZke
aeD9+EU/Pw+rqm3H6o0szOFIpzhRp41FUdW9sybB3Yp3u7c/574E+04Z/e30LMKs
TCtE1QECgYEA3K7KIpw0NH2HXL5C3RHcLmr204xeBfS70riBQQuVUgYdmxak2ima
80RInskY8hRhSGTg0l+VYIH8cmjcUyqMSOELS5XfRH99r4QPiK8AguXg80T4VumY
W3Pf+zEC2ssgP/gYthV0g0Xj5m2QxktOF9tRw5nkg739ZR4dI9lm/iECgYEA2Dnf
uwEDGqHiQRF6/fh5BG/nGVMvrefkqx6WvTJQ3k/M/9WhxB+lr/8yH46TuS8N2b29
FoTf3Mr9T7pr/PWkOPzoY3P56nYbKU8xSwCim9xMzhBMzj8/N9ukJvXy27/VOz56
eQaKqnvdXNGtPJrIMDGHps2KKWlKLyAlapzjVTMCgYAA/W++tACv85g13EykfT4F
n0k4LbsGP9DP4zABQLIMyiY72eAncmRVjwrcW36XJ2xATOONTgx3gF3HjZzfaqNy
eD/6uNNllUTVEryXGmHgNHPL45VRnn6memCY2eFvZdXhM5W4y2PYaunY0MkDercA
+GTngbs6tBF88KOk04bYwQKBgFl68cRgsdkmnwwQYNaTKfmVGYzYaQXNzkqmWPko
xmCJo6tHzC7ubdG8iRCYHzfmahPuuj6EdGPZuSRyYFgJi5Ftz/nAN+84OxtIQ3zn
YWOgskQgaLh9YfsKsQ7Sf1NDOsnOnD5TX7UXl07fEpLe9vNCvAFiU8e5Y9LGudU5
4bYTAoGBAMdX3a3bXp4cZvXNBJ/QLVyxC6fP1Q4haCR1Od3m+T00Jth2IX2dk/fl
p6xiJT1av5JtYabv1dFKaXOS5s1kLGGuCCSKpkvFZm826aQ2AFm0XGqEQDLeei5b
A52Kpy/YJ+RkG4BTFtAooFq6DmA0cnoP6oPvG2h6XtDJwDTPInJb
-----END RSA PRIVATE KEY-----`
caContent = `-----BEGIN CERTIFICATE-----
MIIDbTCCAlWgAwIBAgIUBJvFoCowKich7MMfseJ+DYzzirowDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMDNaGA8yMTIz
MDIxNTEzMjEwM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBAO4to2YMYj0bxgr2FCiweSTSFuPx33zSw2x/s9Wf
OR41bm2DFsyYT5f3sOIKlXZEdLmOKty2e3ho3yC0EyNpVHdykkkHT3aDI17quZax
kYi/URqqtl1Z08A22txolc04hAZisg2BypGi3vql81UW1t3zyloGnJoIAeXR9uca
ljP6Bk3bwsxoVBLi1JtHrO0hHLQaeHmKhAyrys06X0LRdn7Px48yRZlt6FaLSa8X
YiRM0G44bVy/h6BkoQjMYGwVmCVk6zjJ9U7ZPFqdnDMNxAfR+hjDnYodqdLDMTTR
1NPVrnEnNwFx0AMLvgt/ba/45vZCEAmSZnFXFAJJcM7ai9ECAwEAAaNTMFEwHQYD
VR0OBBYEFHlvMPKZoZ3G0YE7JePJOP+8J38fMB8GA1UdIwQYMBaAFHlvMPKZoZ3G
0YE7JePJOP+8J38fMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEMBQADggEB
AMX8dNulADOo9uQgBMyFb9TVra7iY0zZjzv4GY5XY7scd52n6CnfAPvYBBDnTr/O
BgNp5jaujb4+9u/2qhV3f9n+/3WOb2CmPehBgVSzlXqHeQ9lshmgwZPeem2T+8Tm
Nnc/xQnsUfCFszUDxpkr55+aLVM22j02RWqcZ4q7TAaVYL+kdFVMc8FoqG/0ro6A
BjE/Qn0Nn7ciX1VUjDt8l+k7ummPJTmzdi6i6E4AwO9dzrGNgGJ4aWL8cC6xYcIX
goVIRTFeONXSDno/oPjWHpIPt7L15heMpKBHNuzPkKx2YVqPHE5QZxWfS+Lzgx+Q
E2oTTM0rYKOZ8p6000mhvKI=
-----END CERTIFICATE-----`
)
func TestAccount(t *testing.T) { func TestAccount(t *testing.T) {
endpoints := []string{ endpoints := []string{
"192.168.0.2:2379", "192.168.0.2:2379",
@@ -105,34 +32,3 @@ func TestAccount(t *testing.T) {
assert.Equal(t, username, account.User) assert.Equal(t, username, account.User)
assert.Equal(t, anotherPassword, account.Pass) assert.Equal(t, anotherPassword, account.Pass)
} }
func TestTLSMethods(t *testing.T) {
certFile := createTempFile(t, []byte(certContent))
defer os.Remove(certFile)
keyFile := createTempFile(t, []byte(keyContent))
defer os.Remove(keyFile)
caFile := createTempFile(t, []byte(caContent))
defer os.Remove(caFile)
assert.NoError(t, AddTLS([]string{"foo"}, certFile, keyFile, caFile, false))
cfg, ok := GetTLS([]string{"foo"})
assert.True(t, ok)
assert.NotNil(t, cfg)
assert.Error(t, AddTLS([]string{"bar"}, "bad-file", keyFile, caFile, false))
assert.Error(t, AddTLS([]string{"bar"}, certFile, keyFile, "bad-file", false))
}
func createTempFile(t *testing.T, body []byte) string {
tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
if err != nil {
t.Fatal(err)
}
tmpFile.Close()
if err = os.WriteFile(tmpFile.Name(), body, os.ModePerm); err != nil {
t.Fatal(err)
}
return tmpFile.Name()
}

View File

@@ -1,10 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: etcdclient.go // Source: etcdclient.go
//
// Generated by this command:
//
// mockgen -package internal -destination etcdclient_mock.go -source etcdclient.go EtcdClient
//
// Package internal is a generated GoMock package. // Package internal is a generated GoMock package.
package internal package internal
@@ -13,36 +8,35 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
gomock "go.uber.org/mock/gomock"
grpc "google.golang.org/grpc" grpc "google.golang.org/grpc"
) )
// MockEtcdClient is a mock of EtcdClient interface. // MockEtcdClient is a mock of EtcdClient interface
type MockEtcdClient struct { type MockEtcdClient struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockEtcdClientMockRecorder recorder *MockEtcdClientMockRecorder
isgomock struct{}
} }
// MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient. // MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient
type MockEtcdClientMockRecorder struct { type MockEtcdClientMockRecorder struct {
mock *MockEtcdClient mock *MockEtcdClient
} }
// NewMockEtcdClient creates a new mock instance. // NewMockEtcdClient creates a new mock instance
func NewMockEtcdClient(ctrl *gomock.Controller) *MockEtcdClient { func NewMockEtcdClient(ctrl *gomock.Controller) *MockEtcdClient {
mock := &MockEtcdClient{ctrl: ctrl} mock := &MockEtcdClient{ctrl: ctrl}
mock.recorder = &MockEtcdClientMockRecorder{mock} mock.recorder = &MockEtcdClientMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use
func (m *MockEtcdClient) EXPECT() *MockEtcdClientMockRecorder { func (m *MockEtcdClient) EXPECT() *MockEtcdClientMockRecorder {
return m.recorder return m.recorder
} }
// ActiveConnection mocks base method. // ActiveConnection mocks base method
func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn { func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ActiveConnection") ret := m.ctrl.Call(m, "ActiveConnection")
@@ -50,13 +44,13 @@ func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn {
return ret0 return ret0
} }
// ActiveConnection indicates an expected call of ActiveConnection. // ActiveConnection indicates an expected call of ActiveConnection
func (mr *MockEtcdClientMockRecorder) ActiveConnection() *gomock.Call { func (mr *MockEtcdClientMockRecorder) ActiveConnection() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveConnection", reflect.TypeOf((*MockEtcdClient)(nil).ActiveConnection)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveConnection", reflect.TypeOf((*MockEtcdClient)(nil).ActiveConnection))
} }
// Close mocks base method. // Close mocks base method
func (m *MockEtcdClient) Close() error { func (m *MockEtcdClient) Close() error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close") ret := m.ctrl.Call(m, "Close")
@@ -64,13 +58,13 @@ func (m *MockEtcdClient) Close() error {
return ret0 return ret0
} }
// Close indicates an expected call of Close. // Close indicates an expected call of Close
func (mr *MockEtcdClientMockRecorder) Close() *gomock.Call { func (mr *MockEtcdClientMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEtcdClient)(nil).Close)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEtcdClient)(nil).Close))
} }
// Ctx mocks base method. // Ctx mocks base method
func (m *MockEtcdClient) Ctx() context.Context { func (m *MockEtcdClient) Ctx() context.Context {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Ctx") ret := m.ctrl.Call(m, "Ctx")
@@ -78,16 +72,16 @@ func (m *MockEtcdClient) Ctx() context.Context {
return ret0 return ret0
} }
// Ctx indicates an expected call of Ctx. // Ctx indicates an expected call of Ctx
func (mr *MockEtcdClientMockRecorder) Ctx() *gomock.Call { func (mr *MockEtcdClientMockRecorder) Ctx() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ctx", reflect.TypeOf((*MockEtcdClient)(nil).Ctx)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ctx", reflect.TypeOf((*MockEtcdClient)(nil).Ctx))
} }
// Get mocks base method. // Get mocks base method
func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) { func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
varargs := []any{ctx, key} varargs := []interface{}{ctx, key}
for _, a := range opts { for _, a := range opts {
varargs = append(varargs, a) varargs = append(varargs, a)
} }
@@ -97,14 +91,14 @@ func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.O
return ret0, ret1 return ret0, ret1
} }
// Get indicates an expected call of Get. // Get indicates an expected call of Get
func (mr *MockEtcdClientMockRecorder) Get(ctx, key any, opts ...any) *gomock.Call { func (mr *MockEtcdClientMockRecorder) Get(ctx, key interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key}, opts...) varargs := append([]interface{}{ctx, key}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEtcdClient)(nil).Get), varargs...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEtcdClient)(nil).Get), varargs...)
} }
// Grant mocks base method. // Grant mocks base method
func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error) { func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Grant", ctx, ttl) ret := m.ctrl.Call(m, "Grant", ctx, ttl)
@@ -113,13 +107,13 @@ func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseG
return ret0, ret1 return ret0, ret1
} }
// Grant indicates an expected call of Grant. // Grant indicates an expected call of Grant
func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl any) *gomock.Call { func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Grant", reflect.TypeOf((*MockEtcdClient)(nil).Grant), ctx, ttl) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Grant", reflect.TypeOf((*MockEtcdClient)(nil).Grant), ctx, ttl)
} }
// KeepAlive mocks base method. // KeepAlive mocks base method
func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error) { func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeepAlive", ctx, id) ret := m.ctrl.Call(m, "KeepAlive", ctx, id)
@@ -128,16 +122,16 @@ func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-
return ret0, ret1 return ret0, ret1
} }
// KeepAlive indicates an expected call of KeepAlive. // KeepAlive indicates an expected call of KeepAlive
func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id any) *gomock.Call { func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeepAlive", reflect.TypeOf((*MockEtcdClient)(nil).KeepAlive), ctx, id) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeepAlive", reflect.TypeOf((*MockEtcdClient)(nil).KeepAlive), ctx, id)
} }
// Put mocks base method. // Put mocks base method
func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) { func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
varargs := []any{ctx, key, val} varargs := []interface{}{ctx, key, val}
for _, a := range opts { for _, a := range opts {
varargs = append(varargs, a) varargs = append(varargs, a)
} }
@@ -147,14 +141,14 @@ func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clien
return ret0, ret1 return ret0, ret1
} }
// Put indicates an expected call of Put. // Put indicates an expected call of Put
func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val any, opts ...any) *gomock.Call { func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key, val}, opts...) varargs := append([]interface{}{ctx, key, val}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEtcdClient)(nil).Put), varargs...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEtcdClient)(nil).Put), varargs...)
} }
// Revoke mocks base method. // Revoke mocks base method
func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error) { func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Revoke", ctx, id) ret := m.ctrl.Call(m, "Revoke", ctx, id)
@@ -163,16 +157,16 @@ func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clie
return ret0, ret1 return ret0, ret1
} }
// Revoke indicates an expected call of Revoke. // Revoke indicates an expected call of Revoke
func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id any) *gomock.Call { func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Revoke", reflect.TypeOf((*MockEtcdClient)(nil).Revoke), ctx, id) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Revoke", reflect.TypeOf((*MockEtcdClient)(nil).Revoke), ctx, id)
} }
// Watch mocks base method. // Watch mocks base method
func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan { func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan {
m.ctrl.T.Helper() m.ctrl.T.Helper()
varargs := []any{ctx, key} varargs := []interface{}{ctx, key}
for _, a := range opts { for _, a := range opts {
varargs = append(varargs, a) varargs = append(varargs, a)
} }
@@ -181,9 +175,9 @@ func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3
return ret0 return ret0
} }
// Watch indicates an expected call of Watch. // Watch indicates an expected call of Watch
func (mr *MockEtcdClientMockRecorder) Watch(ctx, key any, opts ...any) *gomock.Call { func (mr *MockEtcdClientMockRecorder) Watch(ctx, key interface{}, opts ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
varargs := append([]any{ctx, key}, opts...) varargs := append([]interface{}{ctx, key}, opts...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockEtcdClient)(nil).Watch), varargs...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockEtcdClient)(nil).Watch), varargs...)
} }

View File

@@ -2,7 +2,6 @@ package internal
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"sort" "sort"
@@ -10,30 +9,25 @@ import (
"sync" "sync"
"time" "time"
"github.com/zeromicro/go-zero/core/contextx"
"github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logc" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/syncx" "github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/threading" "github.com/zeromicro/go-zero/core/threading"
"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
) )
const coolDownDeviation = 0.05
var ( var (
registry = Registry{ registry = Registry{
clusters: make(map[string]*cluster), clusters: make(map[string]*cluster),
} }
connManager = syncx.NewResourceManager() connManager = syncx.NewResourceManager()
coolDownUnstable = mathx.NewUnstable(coolDownDeviation)
errClosed = errors.New("etcd monitor chan has been closed")
) )
// A Registry is a registry that manages the etcd client connections. // A Registry is a registry that manages the etcd client connections.
type Registry struct { type Registry struct {
clusters map[string]*cluster clusters map[string]*cluster
lock sync.RWMutex lock sync.Mutex
} }
// GetRegistry returns a global Registry. // GetRegistry returns a global Registry.
@@ -43,148 +37,60 @@ func GetRegistry() *Registry {
// GetConn returns an etcd client connection associated with given endpoints. // GetConn returns an etcd client connection associated with given endpoints.
func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) { func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) {
c, _ := r.getOrCreateCluster(endpoints) c, _ := r.getCluster(endpoints)
return c.getClient() return c.getClient()
} }
// Monitor monitors the key on given etcd endpoints, notify with the given UpdateListener. // Monitor monitors the key on given etcd endpoints, notify with the given UpdateListener.
func (r *Registry) Monitor(endpoints []string, key string, exactMatch bool, l UpdateListener) error { func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener) error {
wkey := watchKey{ c, exists := r.getCluster(endpoints)
key: key,
exactMatch: exactMatch,
}
c, exists := r.getOrCreateCluster(endpoints)
// if exists, the existing values should be updated to the listener. // if exists, the existing values should be updated to the listener.
if exists { if exists {
c.lock.Lock() kvs := c.getCurrent(key)
watcher, ok := c.watchers[wkey] for _, kv := range kvs {
if ok { l.OnAdd(kv)
watcher.listeners = append(watcher.listeners, l)
}
c.lock.Unlock()
if ok {
kvs := c.getCurrent(wkey)
for _, kv := range kvs {
l.OnAdd(kv)
}
return nil
} }
} }
return c.monitor(wkey, l) return c.monitor(key, l)
} }
func (r *Registry) Unmonitor(endpoints []string, key string, exactMatch bool, l UpdateListener) { func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) {
c, exists := r.getCluster(endpoints)
if !exists {
return
}
wkey := watchKey{
key: key,
exactMatch: exactMatch,
}
c.lock.Lock()
defer c.lock.Unlock()
watcher, ok := c.watchers[wkey]
if !ok {
return
}
for i, listener := range watcher.listeners {
if listener == l {
watcher.listeners = append(watcher.listeners[:i], watcher.listeners[i+1:]...)
break
}
}
if len(watcher.listeners) == 0 {
if watcher.cancel != nil {
watcher.cancel()
}
delete(c.watchers, wkey)
}
}
func (r *Registry) getCluster(endpoints []string) (*cluster, bool) {
clusterKey := getClusterKey(endpoints) clusterKey := getClusterKey(endpoints)
r.lock.Lock()
r.lock.RLock() defer r.lock.Unlock()
c, ok := r.clusters[clusterKey] c, exists = r.clusters[clusterKey]
r.lock.RUnlock()
return c, ok
}
func (r *Registry) getOrCreateCluster(endpoints []string) (c *cluster, exists bool) {
c, exists = r.getCluster(endpoints)
if !exists { if !exists {
clusterKey := getClusterKey(endpoints) c = newCluster(endpoints)
r.clusters[clusterKey] = c
r.lock.Lock()
defer r.lock.Unlock()
// double-check locking
c, exists = r.clusters[clusterKey]
if !exists {
c = newCluster(endpoints)
r.clusters[clusterKey] = c
}
} }
return return
} }
type ( type cluster struct {
watchKey struct { endpoints []string
key string key string
exactMatch bool values map[string]map[string]string
} listeners map[string][]UpdateListener
watchGroup *threading.RoutineGroup
watchValue struct { done chan lang.PlaceholderType
listeners []UpdateListener lock sync.Mutex
values map[string]string }
cancel context.CancelFunc
}
cluster struct {
endpoints []string
key string
watchers map[watchKey]*watchValue
watchGroup *threading.RoutineGroup
done chan lang.PlaceholderType
lock sync.RWMutex
}
)
func newCluster(endpoints []string) *cluster { func newCluster(endpoints []string) *cluster {
return &cluster{ return &cluster{
endpoints: endpoints, endpoints: endpoints,
key: getClusterKey(endpoints), key: getClusterKey(endpoints),
watchers: make(map[watchKey]*watchValue), values: make(map[string]map[string]string),
listeners: make(map[string][]UpdateListener),
watchGroup: threading.NewRoutineGroup(), watchGroup: threading.NewRoutineGroup(),
done: make(chan lang.PlaceholderType), done: make(chan lang.PlaceholderType),
} }
} }
func (c *cluster) addListener(key watchKey, l UpdateListener) { func (c *cluster) context(cli EtcdClient) context.Context {
c.lock.Lock() return contextx.ValueOnlyFrom(cli.Ctx())
defer c.lock.Unlock()
watcher, ok := c.watchers[key]
if ok {
watcher.listeners = append(watcher.listeners, l)
return
}
val := newWatchValue()
val.listeners = []UpdateListener{l}
c.watchers[key] = val
} }
func (c *cluster) getClient() (EtcdClient, error) { func (c *cluster) getClient() (EtcdClient, error) {
@@ -198,17 +104,12 @@ func (c *cluster) getClient() (EtcdClient, error) {
return val.(EtcdClient), nil return val.(EtcdClient), nil
} }
func (c *cluster) getCurrent(key watchKey) []KV { func (c *cluster) getCurrent(key string) []KV {
c.lock.RLock() c.lock.Lock()
defer c.lock.RUnlock() defer c.lock.Unlock()
watcher, ok := c.watchers[key] var kvs []KV
if !ok { for k, v := range c.values[key] {
return nil
}
kvs := make([]KV, 0, len(watcher.values))
for k, v := range watcher.values {
kvs = append(kvs, KV{ kvs = append(kvs, KV{
Key: k, Key: k,
Val: v, Val: v,
@@ -218,23 +119,42 @@ func (c *cluster) getCurrent(key watchKey) []KV {
return kvs return kvs
} }
func (c *cluster) handleChanges(key watchKey, kvs []KV) { func (c *cluster) handleChanges(key string, kvs []KV) {
var add []KV
var remove []KV
c.lock.Lock() c.lock.Lock()
watcher, ok := c.watchers[key] listeners := append([]UpdateListener(nil), c.listeners[key]...)
vals, ok := c.values[key]
if !ok { if !ok {
c.lock.Unlock() add = kvs
return vals = make(map[string]string)
for _, kv := range kvs {
vals[kv.Key] = kv.Val
}
c.values[key] = vals
} else {
m := make(map[string]string)
for _, kv := range kvs {
m[kv.Key] = kv.Val
}
for k, v := range vals {
if val, ok := m[k]; !ok || v != val {
remove = append(remove, KV{
Key: k,
Val: v,
})
}
}
for k, v := range m {
if val, ok := vals[k]; !ok || v != val {
add = append(add, KV{
Key: k,
Val: v,
})
}
}
c.values[key] = m
} }
listeners := append([]UpdateListener(nil), watcher.listeners...)
// watcher.values cannot be nil
vals := watcher.values
newVals := make(map[string]string, len(kvs)+len(vals))
for _, kv := range kvs {
newVals[kv.Key] = kv.Val
}
add, remove := calculateChanges(vals, newVals)
watcher.values = newVals
c.lock.Unlock() c.lock.Unlock()
for _, kv := range add { for _, kv := range add {
@@ -249,22 +169,20 @@ func (c *cluster) handleChanges(key watchKey, kvs []KV) {
} }
} }
func (c *cluster) handleWatchEvents(ctx context.Context, key watchKey, events []*clientv3.Event) { func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
c.lock.RLock() c.lock.Lock()
watcher, ok := c.watchers[key] listeners := append([]UpdateListener(nil), c.listeners[key]...)
if !ok { c.lock.Unlock()
c.lock.RUnlock()
return
}
listeners := append([]UpdateListener(nil), watcher.listeners...)
c.lock.RUnlock()
for _, ev := range events { for _, ev := range events {
switch ev.Type { switch ev.Type {
case clientv3.EventTypePut: case clientv3.EventTypePut:
c.lock.Lock() c.lock.Lock()
watcher.values[string(ev.Kv.Key)] = string(ev.Kv.Value) if vals, ok := c.values[key]; ok {
vals[string(ev.Kv.Key)] = string(ev.Kv.Value)
} else {
c.values[key] = map[string]string{string(ev.Kv.Key): string(ev.Kv.Value)}
}
c.lock.Unlock() c.lock.Unlock()
for _, l := range listeners { for _, l := range listeners {
l.OnAdd(KV{ l.OnAdd(KV{
@@ -274,7 +192,9 @@ func (c *cluster) handleWatchEvents(ctx context.Context, key watchKey, events []
} }
case clientv3.EventTypeDelete: case clientv3.EventTypeDelete:
c.lock.Lock() c.lock.Lock()
delete(watcher.values, string(ev.Kv.Key)) if vals, ok := c.values[key]; ok {
delete(vals, string(ev.Kv.Key))
}
c.lock.Unlock() c.lock.Unlock()
for _, l := range listeners { for _, l := range listeners {
l.OnDelete(KV{ l.OnDelete(KV{
@@ -283,32 +203,27 @@ func (c *cluster) handleWatchEvents(ctx context.Context, key watchKey, events []
}) })
} }
default: default:
logc.Errorf(ctx, "Unknown event type: %v", ev.Type) logx.Errorf("Unknown event type: %v", ev.Type)
} }
} }
} }
func (c *cluster) load(cli EtcdClient, key watchKey) int64 { func (c *cluster) load(cli EtcdClient, key string) int64 {
var resp *clientv3.GetResponse var resp *clientv3.GetResponse
for { for {
var err error var err error
ctx, cancel := context.WithTimeout(cli.Ctx(), RequestTimeout) ctx, cancel := context.WithTimeout(c.context(cli), RequestTimeout)
if key.exactMatch { resp, err = cli.Get(ctx, makeKeyPrefix(key), clientv3.WithPrefix())
resp, err = cli.Get(ctx, key.key)
} else {
resp, err = cli.Get(ctx, makeKeyPrefix(key.key), clientv3.WithPrefix())
}
cancel() cancel()
if err == nil { if err == nil {
break break
} }
logc.Errorf(cli.Ctx(), "%s, key: %s, exactMatch: %t", err.Error(), key.key, key.exactMatch) logx.Error(err)
time.Sleep(coolDownUnstable.AroundDuration(coolDownInterval)) time.Sleep(coolDownInterval)
} }
kvs := make([]KV, 0, len(resp.Kvs)) var kvs []KV
for _, ev := range resp.Kvs { for _, ev := range resp.Kvs {
kvs = append(kvs, KV{ kvs = append(kvs, KV{
Key: string(ev.Key), Key: string(ev.Key),
@@ -321,13 +236,16 @@ func (c *cluster) load(cli EtcdClient, key watchKey) int64 {
return resp.Header.Revision return resp.Header.Revision
} }
func (c *cluster) monitor(key watchKey, l UpdateListener) error { func (c *cluster) monitor(key string, l UpdateListener) error {
c.lock.Lock()
c.listeners[key] = append(c.listeners[key], l)
c.lock.Unlock()
cli, err := c.getClient() cli, err := c.getClient()
if err != nil { if err != nil {
return err return err
} }
c.addListener(key, l)
rev := c.load(cli, key) rev := c.load(cli, key)
c.watchGroup.Run(func() { c.watchGroup.Run(func() {
c.watch(cli, key, rev) c.watch(cli, key, rev)
@@ -349,22 +267,16 @@ func (c *cluster) newClient() (EtcdClient, error) {
func (c *cluster) reload(cli EtcdClient) { func (c *cluster) reload(cli EtcdClient) {
c.lock.Lock() c.lock.Lock()
// cancel the previous watches
close(c.done) close(c.done)
c.watchGroup.Wait() c.watchGroup.Wait()
keys := make([]watchKey, 0, len(c.watchers))
for wk, wval := range c.watchers {
keys = append(keys, wk)
if wval.cancel != nil {
wval.cancel()
}
}
c.done = make(chan lang.PlaceholderType) c.done = make(chan lang.PlaceholderType)
c.watchGroup = threading.NewRoutineGroup() c.watchGroup = threading.NewRoutineGroup()
var keys []string
for k := range c.listeners {
keys = append(keys, k)
}
c.lock.Unlock() c.lock.Unlock()
// start new watches
for _, key := range keys { for _, key := range keys {
k := key k := key
c.watchGroup.Run(func() { c.watchGroup.Run(func() {
@@ -374,80 +286,46 @@ func (c *cluster) reload(cli EtcdClient) {
} }
} }
func (c *cluster) watch(cli EtcdClient, key watchKey, rev int64) { func (c *cluster) watch(cli EtcdClient, key string, rev int64) {
for { for {
err := c.watchStream(cli, key, rev) if c.watchStream(cli, key, rev) {
if err == nil {
return return
} }
if rev != 0 && errors.Is(err, rpctypes.ErrCompacted) {
logc.Errorf(cli.Ctx(), "etcd watch stream has been compacted, try to reload, rev %d", rev)
rev = c.load(cli, key)
}
// log the error and retry
logc.Error(cli.Ctx(), err)
} }
} }
func (c *cluster) watchStream(cli EtcdClient, key watchKey, rev int64) error { func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) bool {
ctx, rch := c.setupWatch(cli, key, rev) var rch clientv3.WatchChan
if rev != 0 {
rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix(),
clientv3.WithRev(rev+1))
} else {
rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix())
}
for { for {
select { select {
case wresp, ok := <-rch: case wresp, ok := <-rch:
if !ok { if !ok {
return errClosed logx.Error("etcd monitor chan has been closed")
return false
} }
if wresp.Canceled { if wresp.Canceled {
return fmt.Errorf("etcd monitor chan has been canceled, error: %w", wresp.Err()) logx.Errorf("etcd monitor chan has been canceled, error: %v", wresp.Err())
return false
} }
if wresp.Err() != nil { if wresp.Err() != nil {
return fmt.Errorf("etcd monitor chan error: %w", wresp.Err()) logx.Error(fmt.Sprintf("etcd monitor chan error: %v", wresp.Err()))
return false
} }
c.handleWatchEvents(ctx, key, wresp.Events) c.handleWatchEvents(key, wresp.Events)
case <-ctx.Done():
return nil
case <-c.done: case <-c.done:
return nil return true
} }
} }
} }
func (c *cluster) setupWatch(cli EtcdClient, key watchKey, rev int64) (context.Context, clientv3.WatchChan) {
var (
rch clientv3.WatchChan
ops []clientv3.OpOption
wkey = key.key
)
if !key.exactMatch {
wkey = makeKeyPrefix(key.key)
ops = append(ops, clientv3.WithPrefix())
}
if rev != 0 {
ops = append(ops, clientv3.WithRev(rev+1))
}
ctx, cancel := context.WithCancel(cli.Ctx())
if watcher, ok := c.watchers[key]; ok {
watcher.cancel = cancel
} else {
val := newWatchValue()
val.cancel = cancel
c.lock.Lock()
c.watchers[key] = val
c.lock.Unlock()
}
rch = cli.Watch(clientv3.WithRequireLeader(ctx), wkey, ops...)
return ctx, rch
}
func (c *cluster) watchConnState(cli EtcdClient) { func (c *cluster) watchConnState(cli EtcdClient) {
watcher := newStateWatcher() watcher := newStateWatcher()
watcher.addListener(func() { watcher.addListener(func() {
@@ -459,11 +337,13 @@ func (c *cluster) watchConnState(cli EtcdClient) {
// DialClient dials an etcd cluster with given endpoints. // DialClient dials an etcd cluster with given endpoints.
func DialClient(endpoints []string) (EtcdClient, error) { func DialClient(endpoints []string) (EtcdClient, error) {
cfg := clientv3.Config{ cfg := clientv3.Config{
Endpoints: endpoints, Endpoints: endpoints,
AutoSyncInterval: autoSyncInterval, AutoSyncInterval: autoSyncInterval,
DialTimeout: DialTimeout, DialTimeout: DialTimeout,
RejectOldCluster: true, DialKeepAliveTime: dialKeepAliveTime,
PermitWithoutStream: true, DialKeepAliveTimeout: DialTimeout,
RejectOldCluster: true,
PermitWithoutStream: true,
} }
if account, ok := GetAccount(endpoints); ok { if account, ok := GetAccount(endpoints); ok {
cfg.Username = account.User cfg.Username = account.User
@@ -476,28 +356,6 @@ func DialClient(endpoints []string) (EtcdClient, error) {
return clientv3.New(cfg) return clientv3.New(cfg)
} }
func calculateChanges(oldVals, newVals map[string]string) (add, remove []KV) {
for k, v := range newVals {
if val, ok := oldVals[k]; !ok || v != val {
add = append(add, KV{
Key: k,
Val: v,
})
}
}
for k, v := range oldVals {
if val, ok := newVals[k]; !ok || v != val {
remove = append(remove, KV{
Key: k,
Val: v,
})
}
}
return add, remove
}
func getClusterKey(endpoints []string) string { func getClusterKey(endpoints []string) string {
sort.Strings(endpoints) sort.Strings(endpoints)
return strings.Join(endpoints, endpointsSeparator) return strings.Join(endpoints, endpointsSeparator)
@@ -506,10 +364,3 @@ func getClusterKey(endpoints []string) string {
func makeKeyPrefix(key string) string { func makeKeyPrefix(key string) string {
return fmt.Sprintf("%s%c", key, Delimiter) return fmt.Sprintf("%s%c", key, Delimiter)
} }
// NewClient returns a watchValue that make sure values are not nil.
func newWatchValue() *watchValue {
return &watchValue{
values: make(map[string]string),
}
}

View File

@@ -2,22 +2,18 @@ package internal
import ( import (
"context" "context"
"os"
"sync" "sync"
"testing" "testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/contextx" "github.com/zeromicro/go-zero/core/contextx"
"github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/core/threading"
"go.etcd.io/etcd/api/v3/etcdserverpb" "go.etcd.io/etcd/api/v3/etcdserverpb"
"go.etcd.io/etcd/api/v3/mvccpb" "go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/client/v3/mock/mockserver"
"go.uber.org/mock/gomock"
) )
var mockLock sync.Mutex var mockLock sync.Mutex
@@ -39,9 +35,9 @@ func setMockClient(cli EtcdClient) func() {
func TestGetCluster(t *testing.T) { func TestGetCluster(t *testing.T) {
AddAccount([]string{"first"}, "foo", "bar") AddAccount([]string{"first"}, "foo", "bar")
c1, _ := GetRegistry().getOrCreateCluster([]string{"first"}) c1, _ := GetRegistry().getCluster([]string{"first"})
c2, _ := GetRegistry().getOrCreateCluster([]string{"second"}) c2, _ := GetRegistry().getCluster([]string{"second"})
c3, _ := GetRegistry().getOrCreateCluster([]string{"first"}) c3, _ := GetRegistry().getCluster([]string{"first"})
assert.Equal(t, c1, c3) assert.Equal(t, c1, c3)
assert.NotEqual(t, c1, c2) assert.NotEqual(t, c1, c2)
} }
@@ -51,36 +47,6 @@ func TestGetClusterKey(t *testing.T) {
getClusterKey([]string{"remotehost:5678", "localhost:1234"})) getClusterKey([]string{"remotehost:5678", "localhost:1234"}))
} }
func TestUnmonitor(t *testing.T) {
t.Run("no listener", func(t *testing.T) {
reg := &Registry{
clusters: map[string]*cluster{},
}
assert.NotPanics(t, func() {
reg.Unmonitor([]string{"any"}, "any", false, nil)
})
})
t.Run("no value", func(t *testing.T) {
reg := &Registry{
clusters: map[string]*cluster{
"any": {
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
values: map[string]string{},
},
},
},
},
}
assert.NotPanics(t, func() {
reg.Unmonitor([]string{"any"}, "another", false, nil)
})
})
}
func TestCluster_HandleChanges(t *testing.T) { func TestCluster_HandleChanges(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
l := NewMockUpdateListener(ctrl) l := NewMockUpdateListener(ctrl)
@@ -109,14 +75,8 @@ func TestCluster_HandleChanges(t *testing.T) {
Val: "4", Val: "4",
}) })
c := newCluster([]string{"any"}) c := newCluster([]string{"any"})
key := watchKey{ c.listeners["any"] = []UpdateListener{l}
key: "any", c.handleChanges("any", []KV{
exactMatch: false,
}
c.watchers[key] = &watchValue{
listeners: []UpdateListener{l},
}
c.handleChanges(key, []KV{
{ {
Key: "first", Key: "first",
Val: "1", Val: "1",
@@ -129,8 +89,8 @@ func TestCluster_HandleChanges(t *testing.T) {
assert.EqualValues(t, map[string]string{ assert.EqualValues(t, map[string]string{
"first": "1", "first": "1",
"second": "2", "second": "2",
}, c.watchers[key].values) }, c.values["any"])
c.handleChanges(key, []KV{ c.handleChanges("any", []KV{
{ {
Key: "third", Key: "third",
Val: "3", Val: "3",
@@ -143,7 +103,7 @@ func TestCluster_HandleChanges(t *testing.T) {
assert.EqualValues(t, map[string]string{ assert.EqualValues(t, map[string]string{
"third": "3", "third": "3",
"fourth": "4", "fourth": "4",
}, c.watchers[key].values) }, c.values["any"])
} }
func TestCluster_Load(t *testing.T) { func TestCluster_Load(t *testing.T) {
@@ -163,11 +123,9 @@ func TestCluster_Load(t *testing.T) {
}, nil) }, nil)
cli.EXPECT().Ctx().Return(context.Background()) cli.EXPECT().Ctx().Return(context.Background())
c := &cluster{ c := &cluster{
watchers: make(map[watchKey]*watchValue), values: make(map[string]map[string]string),
} }
c.load(cli, watchKey{ c.load(cli, "any")
key: "any",
})
} }
func TestCluster_Watch(t *testing.T) { func TestCluster_Watch(t *testing.T) {
@@ -199,25 +157,20 @@ func TestCluster_Watch(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
c := &cluster{ c := &cluster{
watchers: make(map[watchKey]*watchValue), listeners: make(map[string][]UpdateListener),
} values: make(map[string]map[string]string),
key := watchKey{
key: "any",
} }
listener := NewMockUpdateListener(ctrl) listener := NewMockUpdateListener(ctrl)
c.watchers[key] = &watchValue{ c.listeners["any"] = []UpdateListener{listener}
listeners: []UpdateListener{listener},
values: make(map[string]string),
}
listener.EXPECT().OnAdd(gomock.Any()).Do(func(kv KV) { listener.EXPECT().OnAdd(gomock.Any()).Do(func(kv KV) {
assert.Equal(t, "hello", kv.Key) assert.Equal(t, "hello", kv.Key)
assert.Equal(t, "world", kv.Val) assert.Equal(t, "world", kv.Val)
wg.Done() wg.Done()
}).MaxTimes(1) }).MaxTimes(1)
listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ any) { listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ interface{}) {
wg.Done() wg.Done()
}).MaxTimes(1) }).MaxTimes(1)
go c.watch(cli, key, 0) go c.watch(cli, "any", 0)
ch <- clientv3.WatchResponse{ ch <- clientv3.WatchResponse{
Events: []*clientv3.Event{ Events: []*clientv3.Event{
{ {
@@ -255,111 +208,17 @@ func TestClusterWatch_RespFailures(t *testing.T) {
ch := make(chan clientv3.WatchResponse) ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
c := &cluster{ c := new(cluster)
watchers: make(map[watchKey]*watchValue),
}
c.done = make(chan lang.PlaceholderType) c.done = make(chan lang.PlaceholderType)
go func() { go func() {
ch <- resp ch <- resp
close(c.done) close(c.done)
}() }()
key := watchKey{ c.watch(cli, "any", 0)
key: "any",
}
c.watch(cli, key, 0)
}) })
} }
} }
func TestCluster_getCurrent(t *testing.T) {
t.Run("no value", func(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
values: map[string]string{},
},
},
}
assert.Nil(t, c.getCurrent(watchKey{
key: "another",
}))
})
}
func TestCluster_handleWatchEvents(t *testing.T) {
t.Run("no value", func(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
values: map[string]string{},
},
},
}
assert.NotPanics(t, func() {
c.handleWatchEvents(context.Background(), watchKey{
key: "another",
}, nil)
})
})
}
func TestCluster_addListener(t *testing.T) {
t.Run("has listener", func(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
listeners: make([]UpdateListener, 0),
},
},
}
assert.NotPanics(t, func() {
c.addListener(watchKey{
key: "any",
}, nil)
})
})
t.Run("no listener", func(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{
{
key: "any",
}: {
listeners: make([]UpdateListener, 0),
},
},
}
assert.NotPanics(t, func() {
c.addListener(watchKey{
key: "another",
}, nil)
})
})
}
func TestCluster_reload(t *testing.T) {
c := &cluster{
watchers: map[watchKey]*watchValue{},
watchGroup: threading.NewRoutineGroup(),
done: make(chan lang.PlaceholderType),
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cli := NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
assert.NotPanics(t, func() {
c.reload(cli)
})
}
func TestClusterWatch_CloseChan(t *testing.T) { func TestClusterWatch_CloseChan(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
@@ -369,17 +228,13 @@ func TestClusterWatch_CloseChan(t *testing.T) {
ch := make(chan clientv3.WatchResponse) ch := make(chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes() cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes() cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
c := &cluster{ c := new(cluster)
watchers: make(map[watchKey]*watchValue),
}
c.done = make(chan lang.PlaceholderType) c.done = make(chan lang.PlaceholderType)
go func() { go func() {
close(ch) close(ch)
close(c.done) close(c.done)
}() }()
c.watch(cli, watchKey{ c.watch(cli, "any", 0)
key: "any",
}, 0)
} }
func TestValueOnlyContext(t *testing.T) { func TestValueOnlyContext(t *testing.T) {
@@ -387,101 +242,3 @@ func TestValueOnlyContext(t *testing.T) {
ctx.Done() ctx.Done()
assert.Nil(t, ctx.Err()) assert.Nil(t, ctx.Err())
} }
func TestDialClient(t *testing.T) {
svr, err := mockserver.StartMockServers(1)
assert.NoError(t, err)
svr.StartAt(0)
certFile := createTempFile(t, []byte(certContent))
defer os.Remove(certFile)
keyFile := createTempFile(t, []byte(keyContent))
defer os.Remove(keyFile)
caFile := createTempFile(t, []byte(caContent))
defer os.Remove(caFile)
endpoints := []string{svr.Servers[0].Address}
AddAccount(endpoints, "foo", "bar")
assert.NoError(t, AddTLS(endpoints, certFile, keyFile, caFile, false))
old := DialTimeout
DialTimeout = time.Millisecond
defer func() {
DialTimeout = old
}()
_, err = DialClient(endpoints)
assert.Error(t, err)
}
func TestRegistry_Monitor(t *testing.T) {
svr, err := mockserver.StartMockServers(1)
assert.NoError(t, err)
svr.StartAt(0)
endpoints := []string{svr.Servers[0].Address}
GetRegistry().lock.Lock()
GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{
{
key: "foo",
exactMatch: true,
}: {
values: map[string]string{
"bar": "baz",
},
},
},
},
}
GetRegistry().lock.Unlock()
assert.Error(t, GetRegistry().Monitor(endpoints, "foo", false, new(mockListener)))
}
func TestRegistry_Unmonitor(t *testing.T) {
svr, err := mockserver.StartMockServers(1)
assert.NoError(t, err)
svr.StartAt(0)
_, cancel := context.WithCancel(context.Background())
endpoints := []string{svr.Servers[0].Address}
GetRegistry().lock.Lock()
GetRegistry().clusters = map[string]*cluster{
getClusterKey(endpoints): {
watchers: map[watchKey]*watchValue{
{
key: "foo",
exactMatch: true,
}: {
values: map[string]string{
"bar": "baz",
},
cancel: cancel,
},
},
},
}
GetRegistry().lock.Unlock()
l := new(mockListener)
assert.NoError(t, GetRegistry().Monitor(endpoints, "foo", true, l))
watchVals := GetRegistry().clusters[getClusterKey(endpoints)].watchers[watchKey{
key: "foo",
exactMatch: true,
}]
assert.Equal(t, 1, len(watchVals.listeners))
GetRegistry().Unmonitor(endpoints, "foo", true, l)
watchVals = GetRegistry().clusters[getClusterKey(endpoints)].watchers[watchKey{
key: "foo",
exactMatch: true,
}]
assert.Nil(t, watchVals)
}
type mockListener struct {
}
func (m *mockListener) OnAdd(_ KV) {
}
func (m *mockListener) OnDelete(_ KV) {
}

View File

@@ -1,10 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: statewatcher.go // Source: statewatcher.go
//
// Generated by this command:
//
// mockgen -package internal -destination statewatcher_mock.go -source statewatcher.go etcdConn
//
// Package internal is a generated GoMock package. // Package internal is a generated GoMock package.
package internal package internal
@@ -13,35 +8,34 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
gomock "go.uber.org/mock/gomock" gomock "github.com/golang/mock/gomock"
connectivity "google.golang.org/grpc/connectivity" connectivity "google.golang.org/grpc/connectivity"
) )
// MocketcdConn is a mock of etcdConn interface. // MocketcdConn is a mock of etcdConn interface
type MocketcdConn struct { type MocketcdConn struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MocketcdConnMockRecorder recorder *MocketcdConnMockRecorder
isgomock struct{}
} }
// MocketcdConnMockRecorder is the mock recorder for MocketcdConn. // MocketcdConnMockRecorder is the mock recorder for MocketcdConn
type MocketcdConnMockRecorder struct { type MocketcdConnMockRecorder struct {
mock *MocketcdConn mock *MocketcdConn
} }
// NewMocketcdConn creates a new mock instance. // NewMocketcdConn creates a new mock instance
func NewMocketcdConn(ctrl *gomock.Controller) *MocketcdConn { func NewMocketcdConn(ctrl *gomock.Controller) *MocketcdConn {
mock := &MocketcdConn{ctrl: ctrl} mock := &MocketcdConn{ctrl: ctrl}
mock.recorder = &MocketcdConnMockRecorder{mock} mock.recorder = &MocketcdConnMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use
func (m *MocketcdConn) EXPECT() *MocketcdConnMockRecorder { func (m *MocketcdConn) EXPECT() *MocketcdConnMockRecorder {
return m.recorder return m.recorder
} }
// GetState mocks base method. // GetState mocks base method
func (m *MocketcdConn) GetState() connectivity.State { func (m *MocketcdConn) GetState() connectivity.State {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetState") ret := m.ctrl.Call(m, "GetState")
@@ -49,13 +43,13 @@ func (m *MocketcdConn) GetState() connectivity.State {
return ret0 return ret0
} }
// GetState indicates an expected call of GetState. // GetState indicates an expected call of GetState
func (mr *MocketcdConnMockRecorder) GetState() *gomock.Call { func (mr *MocketcdConnMockRecorder) GetState() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MocketcdConn)(nil).GetState)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MocketcdConn)(nil).GetState))
} }
// WaitForStateChange mocks base method. // WaitForStateChange mocks base method
func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool { func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WaitForStateChange", ctx, sourceState) ret := m.ctrl.Call(m, "WaitForStateChange", ctx, sourceState)
@@ -63,8 +57,8 @@ func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState conne
return ret0 return ret0
} }
// WaitForStateChange indicates an expected call of WaitForStateChange. // WaitForStateChange indicates an expected call of WaitForStateChange
func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState any) *gomock.Call { func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForStateChange", reflect.TypeOf((*MocketcdConn)(nil).WaitForStateChange), ctx, sourceState) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForStateChange", reflect.TypeOf((*MocketcdConn)(nil).WaitForStateChange), ctx, sourceState)
} }

View File

@@ -4,7 +4,7 @@ import (
"sync" "sync"
"testing" "testing"
"go.uber.org/mock/gomock" "github.com/golang/mock/gomock"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
) )

View File

@@ -10,7 +10,6 @@ type (
} }
// UpdateListener wraps the OnAdd and OnDelete methods. // UpdateListener wraps the OnAdd and OnDelete methods.
// The implementation should be thread-safe and idempotent.
UpdateListener interface { UpdateListener interface {
OnAdd(kv KV) OnAdd(kv KV)
OnDelete(kv KV) OnDelete(kv KV)

View File

@@ -1,10 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: updatelistener.go // Source: updatelistener.go
//
// Generated by this command:
//
// mockgen -package internal -destination updatelistener_mock.go -source updatelistener.go UpdateListener
//
// Package internal is a generated GoMock package. // Package internal is a generated GoMock package.
package internal package internal
@@ -12,53 +7,52 @@ package internal
import ( import (
reflect "reflect" reflect "reflect"
gomock "go.uber.org/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
// MockUpdateListener is a mock of UpdateListener interface. // MockUpdateListener is a mock of UpdateListener interface
type MockUpdateListener struct { type MockUpdateListener struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockUpdateListenerMockRecorder recorder *MockUpdateListenerMockRecorder
isgomock struct{}
} }
// MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener. // MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener
type MockUpdateListenerMockRecorder struct { type MockUpdateListenerMockRecorder struct {
mock *MockUpdateListener mock *MockUpdateListener
} }
// NewMockUpdateListener creates a new mock instance. // NewMockUpdateListener creates a new mock instance
func NewMockUpdateListener(ctrl *gomock.Controller) *MockUpdateListener { func NewMockUpdateListener(ctrl *gomock.Controller) *MockUpdateListener {
mock := &MockUpdateListener{ctrl: ctrl} mock := &MockUpdateListener{ctrl: ctrl}
mock.recorder = &MockUpdateListenerMockRecorder{mock} mock.recorder = &MockUpdateListenerMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use
func (m *MockUpdateListener) EXPECT() *MockUpdateListenerMockRecorder { func (m *MockUpdateListener) EXPECT() *MockUpdateListenerMockRecorder {
return m.recorder return m.recorder
} }
// OnAdd mocks base method. // OnAdd mocks base method
func (m *MockUpdateListener) OnAdd(kv KV) { func (m *MockUpdateListener) OnAdd(kv KV) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "OnAdd", kv) m.ctrl.Call(m, "OnAdd", kv)
} }
// OnAdd indicates an expected call of OnAdd. // OnAdd indicates an expected call of OnAdd
func (mr *MockUpdateListenerMockRecorder) OnAdd(kv any) *gomock.Call { func (mr *MockUpdateListenerMockRecorder) OnAdd(kv interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAdd", reflect.TypeOf((*MockUpdateListener)(nil).OnAdd), kv) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAdd", reflect.TypeOf((*MockUpdateListener)(nil).OnAdd), kv)
} }
// OnDelete mocks base method. // OnDelete mocks base method
func (m *MockUpdateListener) OnDelete(kv KV) { func (m *MockUpdateListener) OnDelete(kv KV) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "OnDelete", kv) m.ctrl.Call(m, "OnDelete", kv)
} }
// OnDelete indicates an expected call of OnDelete. // OnDelete indicates an expected call of OnDelete
func (mr *MockUpdateListenerMockRecorder) OnDelete(kv any) *gomock.Call { func (mr *MockUpdateListenerMockRecorder) OnDelete(kv interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDelete", reflect.TypeOf((*MockUpdateListener)(nil).OnDelete), kv) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDelete", reflect.TypeOf((*MockUpdateListener)(nil).OnDelete), kv)
} }

View File

@@ -9,6 +9,7 @@ const (
autoSyncInterval = time.Minute autoSyncInterval = time.Minute
coolDownInterval = time.Second coolDownInterval = time.Second
dialTimeout = 5 * time.Second dialTimeout = 5 * time.Second
dialKeepAliveTime = 5 * time.Second
requestTimeout = 3 * time.Second requestTimeout = 3 * time.Second
endpointsSeparator = "," endpointsSeparator = ","
) )

View File

@@ -5,7 +5,6 @@ import (
"github.com/zeromicro/go-zero/core/discov/internal" "github.com/zeromicro/go-zero/core/discov/internal"
"github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logc"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/proc" "github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/syncx" "github.com/zeromicro/go-zero/core/syncx"
@@ -92,12 +91,12 @@ func (p *Publisher) doKeepAlive() error {
default: default:
cli, err := p.doRegister() cli, err := p.doRegister()
if err != nil { if err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher doRegister: %v", err) logx.Errorf("etcd publisher doRegister: %s", err.Error())
break break
} }
if err := p.keepAliveAsync(cli); err != nil { if err := p.keepAliveAsync(cli); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher keepAliveAsync: %v", err) logx.Errorf("etcd publisher keepAliveAsync: %s", err.Error())
break break
} }
@@ -125,48 +124,23 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
} }
threading.GoSafe(func() { threading.GoSafe(func() {
wch := cli.Watch(cli.Ctx(), p.fullKey, clientv3.WithFilterPut())
for { for {
select { select {
case _, ok := <-ch: case _, ok := <-ch:
if !ok { if !ok {
p.revoke(cli) p.revoke(cli)
if err := p.doKeepAlive(); err != nil { if err := p.doKeepAlive(); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err) logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
} }
return return
} }
case c := <-wch:
if c.Err() != nil {
logc.Errorf(cli.Ctx(), "etcd publisher watch: %v", c.Err())
if err := p.doKeepAlive(); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
}
return
}
for _, evt := range c.Events {
if evt.Type == clientv3.EventTypeDelete {
logc.Infof(cli.Ctx(), "etcd publisher watch: %s, event: %v",
evt.Kv.Key, evt.Type)
_, err := cli.Put(cli.Ctx(), p.fullKey, p.value, clientv3.WithLease(p.lease))
if err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher re-put key: %v", err)
} else {
logc.Infof(cli.Ctx(), "etcd publisher re-put key: %s, value: %s",
p.fullKey, p.value)
}
}
}
case <-p.pauseChan: case <-p.pauseChan:
logc.Infof(cli.Ctx(), "paused etcd renew, key: %s, value: %s", p.key, p.value) logx.Infof("paused etcd renew, key: %s, value: %s", p.key, p.value)
p.revoke(cli) p.revoke(cli)
select { select {
case <-p.resumeChan: case <-p.resumeChan:
if err := p.doKeepAlive(); err != nil { if err := p.doKeepAlive(); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err) logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
} }
return return
case <-p.quit.Done(): case <-p.quit.Done():
@@ -201,7 +175,7 @@ func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, erro
func (p *Publisher) revoke(cli internal.EtcdClient) { func (p *Publisher) revoke(cli internal.EtcdClient) {
if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil { if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil {
logc.Errorf(cli.Ctx(), "etcd publisher revoke: %v", err) logx.Errorf("etcd publisher revoke: %s", err.Error())
} }
} }

View File

@@ -1,99 +1,18 @@
package discov package discov
import ( import (
"context"
"errors" "errors"
"net"
"os"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/discov/internal" "github.com/zeromicro/go-zero/core/discov/internal"
"github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx" "github.com/zeromicro/go-zero/core/stringx"
"go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/mock/gomock"
"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
)
const (
certContent = `-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUEg9GVO2oaPn+YSmiqmFIuAo10WIwDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMjNaGA8yMTIz
MDIxNTEzMjEyM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBALplXlWsIf0O/IgnIplmiZHKGnxyfyufyE2FBRNk
OofRqbKuPH8GNqbkvZm7N29fwTDAQ+mViAggCkDht4hOzoWJMA7KYJt8JnTSWL48
M1lcrpc9DL2gszC/JF/FGvyANbBtLklkZPFBGdHUX14pjrT937wqPtm+SqUHSvRT
B7bmwmm2drRcmhpVm98LSlV7uQ2EgnJgsLjBPITKUejLmVLHfgX0RwQ2xIpX9pS4
FCe1BTacwl2gGp7Mje7y4Mfv3o0ArJW6Tuwbjx59ZXwb1KIP71b7bT04AVS8ZeYO
UMLKKuB5UR9x9Rn6cLXOTWBpcMVyzDgrAFLZjnE9LPUolZMCAwEAAaNRME8wHwYD
VR0jBBgwFoAUeW8w8pmhncbRgTsl48k4/7wnfx8wCQYDVR0TBAIwADALBgNVHQ8E
BAMCBPAwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBDAUAA4IBAQAI
y9xaoS88CLPBsX6mxfcTAFVfGNTRW9VN9Ng1cCnUR+YGoXGM/l+qP4f7p8ocdGwK
iYZErVTzXYIn+D27//wpY3klJk3gAnEUBT3QRkStBw7XnpbeZ2oPBK+cmDnCnZPS
BIF1wxPX7vIgaxs5Zsdqwk3qvZ4Djr2wP7LabNWTLSBKgQoUY45Liw6pffLwcGF9
UKlu54bvGze2SufISCR3ib+I+FLvqpvJhXToZWYb/pfI/HccuCL1oot1x8vx6DQy
U+TYxlZsKS5mdNxAX3dqEkEMsgEi+g/tzDPXJImfeCGGBhIOXLm8SRypiuGdEbc9
xkWYxRPegajuEZGvCqVs
-----END CERTIFICATE-----`
keyContent = `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAumVeVawh/Q78iCcimWaJkcoafHJ/K5/ITYUFE2Q6h9Gpsq48
fwY2puS9mbs3b1/BMMBD6ZWICCAKQOG3iE7OhYkwDspgm3wmdNJYvjwzWVyulz0M
vaCzML8kX8Ua/IA1sG0uSWRk8UEZ0dRfXimOtP3fvCo+2b5KpQdK9FMHtubCabZ2
tFyaGlWb3wtKVXu5DYSCcmCwuME8hMpR6MuZUsd+BfRHBDbEilf2lLgUJ7UFNpzC
XaAansyN7vLgx+/ejQCslbpO7BuPHn1lfBvUog/vVvttPTgBVLxl5g5Qwsoq4HlR
H3H1Gfpwtc5NYGlwxXLMOCsAUtmOcT0s9SiVkwIDAQABAoIBAD5meTJNMgO55Kjg
ESExxpRcCIno+tHr5+6rvYtEXqPheOIsmmwb9Gfi4+Z3WpOaht5/Pz0Ppj6yGzyl
U//6AgGKb+BDuBvVcDpjwPnOxZIBCSHwejdxeQu0scSuA97MPS0XIAvJ5FEv7ijk
5Bht6SyGYURpECltHygoTNuGgGqmO+McCJRLE9L09lTBI6UQ/JQwWJqSr7wx6iPU
M1Ze/srIV+7cyEPu6i0DGjS1gSQKkX68Lqn1w6oE290O+OZvleO0gZ02fLDWCZke
aeD9+EU/Pw+rqm3H6o0szOFIpzhRp41FUdW9sybB3Yp3u7c/574E+04Z/e30LMKs
TCtE1QECgYEA3K7KIpw0NH2HXL5C3RHcLmr204xeBfS70riBQQuVUgYdmxak2ima
80RInskY8hRhSGTg0l+VYIH8cmjcUyqMSOELS5XfRH99r4QPiK8AguXg80T4VumY
W3Pf+zEC2ssgP/gYthV0g0Xj5m2QxktOF9tRw5nkg739ZR4dI9lm/iECgYEA2Dnf
uwEDGqHiQRF6/fh5BG/nGVMvrefkqx6WvTJQ3k/M/9WhxB+lr/8yH46TuS8N2b29
FoTf3Mr9T7pr/PWkOPzoY3P56nYbKU8xSwCim9xMzhBMzj8/N9ukJvXy27/VOz56
eQaKqnvdXNGtPJrIMDGHps2KKWlKLyAlapzjVTMCgYAA/W++tACv85g13EykfT4F
n0k4LbsGP9DP4zABQLIMyiY72eAncmRVjwrcW36XJ2xATOONTgx3gF3HjZzfaqNy
eD/6uNNllUTVEryXGmHgNHPL45VRnn6memCY2eFvZdXhM5W4y2PYaunY0MkDercA
+GTngbs6tBF88KOk04bYwQKBgFl68cRgsdkmnwwQYNaTKfmVGYzYaQXNzkqmWPko
xmCJo6tHzC7ubdG8iRCYHzfmahPuuj6EdGPZuSRyYFgJi5Ftz/nAN+84OxtIQ3zn
YWOgskQgaLh9YfsKsQ7Sf1NDOsnOnD5TX7UXl07fEpLe9vNCvAFiU8e5Y9LGudU5
4bYTAoGBAMdX3a3bXp4cZvXNBJ/QLVyxC6fP1Q4haCR1Od3m+T00Jth2IX2dk/fl
p6xiJT1av5JtYabv1dFKaXOS5s1kLGGuCCSKpkvFZm826aQ2AFm0XGqEQDLeei5b
A52Kpy/YJ+RkG4BTFtAooFq6DmA0cnoP6oPvG2h6XtDJwDTPInJb
-----END RSA PRIVATE KEY-----`
caContent = `-----BEGIN CERTIFICATE-----
MIIDbTCCAlWgAwIBAgIUBJvFoCowKich7MMfseJ+DYzzirowDQYJKoZIhvcNAQEM
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAzMTExMzIxMDNaGA8yMTIz
MDIxNTEzMjEwM1owRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBAO4to2YMYj0bxgr2FCiweSTSFuPx33zSw2x/s9Wf
OR41bm2DFsyYT5f3sOIKlXZEdLmOKty2e3ho3yC0EyNpVHdykkkHT3aDI17quZax
kYi/URqqtl1Z08A22txolc04hAZisg2BypGi3vql81UW1t3zyloGnJoIAeXR9uca
ljP6Bk3bwsxoVBLi1JtHrO0hHLQaeHmKhAyrys06X0LRdn7Px48yRZlt6FaLSa8X
YiRM0G44bVy/h6BkoQjMYGwVmCVk6zjJ9U7ZPFqdnDMNxAfR+hjDnYodqdLDMTTR
1NPVrnEnNwFx0AMLvgt/ba/45vZCEAmSZnFXFAJJcM7ai9ECAwEAAaNTMFEwHQYD
VR0OBBYEFHlvMPKZoZ3G0YE7JePJOP+8J38fMB8GA1UdIwQYMBaAFHlvMPKZoZ3G
0YE7JePJOP+8J38fMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEMBQADggEB
AMX8dNulADOo9uQgBMyFb9TVra7iY0zZjzv4GY5XY7scd52n6CnfAPvYBBDnTr/O
BgNp5jaujb4+9u/2qhV3f9n+/3WOb2CmPehBgVSzlXqHeQ9lshmgwZPeem2T+8Tm
Nnc/xQnsUfCFszUDxpkr55+aLVM22j02RWqcZ4q7TAaVYL+kdFVMc8FoqG/0ro6A
BjE/Qn0Nn7ciX1VUjDt8l+k7ummPJTmzdi6i6E4AwO9dzrGNgGJ4aWL8cC6xYcIX
goVIRTFeONXSDno/oPjWHpIPt7L15heMpKBHNuzPkKx2YVqPHE5QZxWfS+Lzgx+Q
E2oTTM0rYKOZ8p6000mhvKI=
-----END CERTIFICATE-----`
) )
func init() { func init() {
@@ -118,7 +37,7 @@ func TestPublisher_register(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
} }
func TestPublisher_registerWithOptions(t *testing.T) { func TestPublisher_registerWithId(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
const id = 2 const id = 2
@@ -130,15 +49,7 @@ func TestPublisher_registerWithOptions(t *testing.T) {
ID: 1, ID: 1,
}, nil) }, nil)
cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any()) cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", id), "thevalue", gomock.Any())
pub := NewPublisher(nil, "thekey", "thevalue", WithId(id))
certFile := createTempFile(t, []byte(certContent))
defer os.Remove(certFile)
keyFile := createTempFile(t, []byte(keyContent))
defer os.Remove(keyFile)
caFile := createTempFile(t, []byte(caContent))
defer os.Remove(caFile)
pub := NewPublisher(nil, "thekey", "thevalue", WithId(id),
WithPubEtcdTLS(certFile, keyFile, caFile, true))
_, err := pub.register(cli) _, err := pub.register(cli)
assert.Nil(t, err) assert.Nil(t, err)
} }
@@ -212,12 +123,9 @@ func TestPublisher_keepAliveAsyncQuit(t *testing.T) {
defer restore() defer restore()
cli.EXPECT().Ctx().AnyTimes() cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id) cli.EXPECT().KeepAlive(gomock.Any(), id)
// Add Watch expectation for the new watch mechanism
watchChan := make(<-chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) { cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) {
wg.Done() wg.Done()
}) })
pub := NewPublisher(nil, "thekey", "thevalue") pub := NewPublisher(nil, "thekey", "thevalue")
@@ -236,13 +144,10 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
defer restore() defer restore()
cli.EXPECT().Ctx().AnyTimes() cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id) cli.EXPECT().KeepAlive(gomock.Any(), id)
// Add Watch expectation for the new watch mechanism
watchChan := make(<-chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
pub := NewPublisher(nil, "thekey", "thevalue") pub := NewPublisher(nil, "thekey", "thevalue")
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) { cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ interface{}) {
pub.Stop() pub.Stop()
wg.Done() wg.Done()
}) })
@@ -252,112 +157,6 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
wg.Wait() wg.Wait()
} }
// Test case for key deletion and re-registration (covers lines 148-155)
func TestPublisher_keepAliveAsyncKeyDeletion(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Create a watch channel that will send a delete event
watchChan := make(chan clientv3.WatchResponse, 1)
watchResp := clientv3.WatchResponse{
Events: []*clientv3.Event{{
Type: clientv3.EventTypeDelete,
Kv: &mvccpb.KeyValue{
Key: []byte("thekey"),
},
}},
}
watchChan <- watchResp
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return((<-chan clientv3.WatchResponse)(watchChan))
var wg sync.WaitGroup
wg.Add(1) // Only wait for Revoke call
// Use a channel to signal when Put has been called
putCalled := make(chan struct{})
// Expect the re-put operation when key is deleted
cli.EXPECT().Put(gomock.Any(), "thekey", "thevalue", gomock.Any()).Do(func(_, _, _, _ any) {
close(putCalled) // Signal that Put has been called
}).Return(nil, nil)
// Expect revoke when Stop is called
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
wg.Done()
})
pub := NewPublisher(nil, "thekey", "thevalue")
pub.lease = id
pub.fullKey = "thekey"
assert.Nil(t, pub.keepAliveAsync(cli))
// Wait for Put to be called, then stop
<-putCalled
pub.Stop()
wg.Wait()
}
// Test case for key deletion with re-put error (covers error branch in lines 151-152)
func TestPublisher_keepAliveAsyncKeyDeletionPutError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
cli := internal.NewMockEtcdClient(ctrl)
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Create a watch channel that will send a delete event
watchChan := make(chan clientv3.WatchResponse, 1)
watchResp := clientv3.WatchResponse{
Events: []*clientv3.Event{{
Type: clientv3.EventTypeDelete,
Kv: &mvccpb.KeyValue{
Key: []byte("thekey"),
},
}},
}
watchChan <- watchResp
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return((<-chan clientv3.WatchResponse)(watchChan))
var wg sync.WaitGroup
wg.Add(1) // Only wait for Revoke call
// Use a channel to signal when Put has been called
putCalled := make(chan struct{})
// Expect the re-put operation to fail
cli.EXPECT().Put(gomock.Any(), "thekey", "thevalue", gomock.Any()).Do(func(_, _, _, _ any) {
close(putCalled) // Signal that Put has been called
}).Return(nil, errors.New("put error"))
// Expect revoke when Stop is called
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
wg.Done()
})
pub := NewPublisher(nil, "thekey", "thevalue")
pub.lease = id
pub.fullKey = "thekey"
assert.Nil(t, pub.keepAliveAsync(cli))
// Wait for Put to be called, then stop
<-putCalled
pub.Stop()
wg.Wait()
}
func TestPublisher_Resume(t *testing.T) { func TestPublisher_Resume(t *testing.T) {
publisher := new(Publisher) publisher := new(Publisher)
publisher.resumeChan = make(chan lang.PlaceholderType) publisher.resumeChan = make(chan lang.PlaceholderType)
@@ -370,95 +169,3 @@ func TestPublisher_Resume(t *testing.T) {
}() }()
<-publisher.resumeChan <-publisher.resumeChan
} }
func TestPublisher_keepAliveAsync(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
const id clientv3.LeaseID = 1
conn := createMockConn(t)
defer conn.Close()
cli := internal.NewMockEtcdClient(ctrl)
cli.EXPECT().ActiveConnection().Return(conn).AnyTimes()
cli.EXPECT().Close()
defer cli.Close()
cli.ActiveConnection()
restore := setMockClient(cli)
defer restore()
cli.EXPECT().Ctx().AnyTimes()
cli.EXPECT().KeepAlive(gomock.Any(), id)
// Add Watch expectation for the new watch mechanism
watchChan := make(<-chan clientv3.WatchResponse)
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{
ID: 1,
}, nil)
cli.EXPECT().Put(gomock.Any(), makeEtcdKey("thekey", int64(id)), "thevalue", gomock.Any())
var wg sync.WaitGroup
wg.Add(1)
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
wg.Done()
})
pub := NewPublisher([]string{"the-endpoint"}, "thekey", "thevalue")
pub.lease = id
assert.Nil(t, pub.KeepAlive())
pub.Stop()
wg.Wait()
}
func createMockConn(t *testing.T) *grpc.ClientConn {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening. Err: %v", err)
}
defer lis.Close()
lisAddr := resolver.Address{Addr: lis.Addr().String()}
lisDone := make(chan struct{})
dialDone := make(chan struct{})
// 1st listener accepts the connection and then does nothing
go func() {
defer close(lisDone)
conn, err := lis.Accept()
if err != nil {
t.Errorf("Error while accepting. Err: %v", err)
return
}
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(http2.Setting{}); err != nil {
t.Errorf("Error while writing settings. Err: %v", err)
return
}
<-dialDone // Close conn only after dial returns.
}()
r := manual.NewBuilderWithScheme("whatever")
r.InitialState(resolver.State{Addresses: []resolver.Address{lisAddr}})
client, err := grpc.DialContext(context.Background(), r.Scheme()+":///test.server",
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r))
close(dialDone)
if err != nil {
t.Fatalf("Dial failed. Err: %v", err)
}
timeout := time.After(1 * time.Second)
select {
case <-timeout:
t.Fatal("timed out waiting for server to finish")
case <-lisDone:
}
return client
}
func createTempFile(t *testing.T, body []byte) string {
tmpFile, err := os.CreateTemp(os.TempDir(), "go-unit-*.tmp")
if err != nil {
t.Fatal(err)
}
tmpFile.Close()
if err = os.WriteFile(tmpFile.Name(), body, os.ModePerm); err != nil {
t.Fatal(err)
}
return tmpFile.Name()
}

View File

@@ -15,11 +15,9 @@ type (
// A Subscriber is used to subscribe the given key on an etcd cluster. // A Subscriber is used to subscribe the given key on an etcd cluster.
Subscriber struct { Subscriber struct {
endpoints []string endpoints []string
exclusive bool exclusive bool
key string items *container
exactMatch bool
items *container
} }
) )
@@ -30,14 +28,13 @@ type (
func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscriber, error) { func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscriber, error) {
sub := &Subscriber{ sub := &Subscriber{
endpoints: endpoints, endpoints: endpoints,
key: key,
} }
for _, opt := range opts { for _, opt := range opts {
opt(sub) opt(sub)
} }
sub.items = newContainer(sub.exclusive) sub.items = newContainer(sub.exclusive)
if err := internal.GetRegistry().Monitor(endpoints, key, sub.exactMatch, sub.items); err != nil { if err := internal.GetRegistry().Monitor(endpoints, key, sub.items); err != nil {
return nil, err return nil, err
} }
@@ -49,11 +46,6 @@ func (s *Subscriber) AddListener(listener func()) {
s.items.addListener(listener) s.items.addListener(listener)
} }
// Close closes the subscriber.
func (s *Subscriber) Close() {
internal.GetRegistry().Unmonitor(s.endpoints, s.key, s.exactMatch, s.items)
}
// Values returns all the subscription values. // Values returns all the subscription values.
func (s *Subscriber) Values() []string { func (s *Subscriber) Values() []string {
return s.items.getValues() return s.items.getValues()
@@ -67,13 +59,6 @@ func Exclusive() SubOption {
} }
} }
// WithExactMatch turn off querying using key prefixes.
func WithExactMatch() SubOption {
return func(sub *Subscriber) {
sub.exactMatch = true
}
}
// WithSubEtcdAccount provides the etcd username/password. // WithSubEtcdAccount provides the etcd username/password.
func WithSubEtcdAccount(user, pass string) SubOption { func WithSubEtcdAccount(user, pass string) SubOption {
return func(sub *Subscriber) { return func(sub *Subscriber) {

View File

@@ -225,28 +225,3 @@ func TestWithSubEtcdAccount(t *testing.T) {
assert.Equal(t, user, account.User) assert.Equal(t, user, account.User)
assert.Equal(t, "bar", account.Pass) assert.Equal(t, "bar", account.Pass)
} }
func TestWithExactMatch(t *testing.T) {
sub := new(Subscriber)
WithExactMatch()(sub)
sub.items = newContainer(sub.exclusive)
var count int32
sub.AddListener(func() {
atomic.AddInt32(&count, 1)
})
sub.items.notifyChange()
assert.Empty(t, sub.Values())
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
}
func TestSubscriberClose(t *testing.T) {
l := newContainer(false)
sub := &Subscriber{
endpoints: []string{"localhost:12379"},
key: "foo",
items: l,
}
assert.NotPanics(t, func() {
sub.Close()
})
}

View File

@@ -1,21 +1,18 @@
package errorx package errorx
import ( import "bytes"
"errors"
"sync" type (
// A BatchError is an error that can hold multiple errors.
BatchError struct {
errs errorArray
}
errorArray []error
) )
// BatchError is an error that can hold multiple errors. // Add adds errs to be, nil errors are ignored.
type BatchError struct {
errs []error
lock sync.RWMutex
}
// Add adds one or more non-nil errors to the BatchError instance.
func (be *BatchError) Add(errs ...error) { func (be *BatchError) Add(errs ...error) {
be.lock.Lock()
defer be.lock.Unlock()
for _, err := range errs { for _, err := range errs {
if err != nil { if err != nil {
be.errs = append(be.errs, err) be.errs = append(be.errs, err)
@@ -23,20 +20,33 @@ func (be *BatchError) Add(errs ...error) {
} }
} }
// Err returns an error that represents all accumulated errors. // Err returns an error that represents all errors.
// It returns nil if there are no errors.
func (be *BatchError) Err() error { func (be *BatchError) Err() error {
be.lock.RLock() switch len(be.errs) {
defer be.lock.RUnlock() case 0:
return nil
// If there are no non-nil errors, errors.Join(...) returns nil. case 1:
return errors.Join(be.errs...) return be.errs[0]
default:
return be.errs
}
} }
// NotNil checks if there is at least one error inside the BatchError. // NotNil checks if any error inside.
func (be *BatchError) NotNil() bool { func (be *BatchError) NotNil() bool {
be.lock.RLock()
defer be.lock.RUnlock()
return len(be.errs) > 0 return len(be.errs) > 0
} }
// Error returns a string that represents inside errors.
func (ea errorArray) Error() string {
var buf bytes.Buffer
for i := range ea {
if i > 0 {
buf.WriteByte('\n')
}
buf.WriteString(ea[i].Error())
}
return buf.String()
}

View File

@@ -3,7 +3,6 @@ package errorx
import ( import (
"errors" "errors"
"fmt" "fmt"
"sync"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -34,7 +33,7 @@ func TestBatchErrorNilFromFunc(t *testing.T) {
func TestBatchErrorOneError(t *testing.T) { func TestBatchErrorOneError(t *testing.T) {
var batch BatchError var batch BatchError
batch.Add(errors.New(err1)) batch.Add(errors.New(err1))
assert.NotNil(t, batch.Err()) assert.NotNil(t, batch)
assert.Equal(t, err1, batch.Err().Error()) assert.Equal(t, err1, batch.Err().Error())
assert.True(t, batch.NotNil()) assert.True(t, batch.NotNil())
} }
@@ -43,105 +42,7 @@ func TestBatchErrorWithErrors(t *testing.T) {
var batch BatchError var batch BatchError
batch.Add(errors.New(err1)) batch.Add(errors.New(err1))
batch.Add(errors.New(err2)) batch.Add(errors.New(err2))
assert.NotNil(t, batch.Err()) assert.NotNil(t, batch)
assert.Equal(t, fmt.Sprintf("%s\n%s", err1, err2), batch.Err().Error()) assert.Equal(t, fmt.Sprintf("%s\n%s", err1, err2), batch.Err().Error())
assert.True(t, batch.NotNil()) assert.True(t, batch.NotNil())
} }
func TestBatchErrorConcurrentAdd(t *testing.T) {
const count = 10000
var batch BatchError
var wg sync.WaitGroup
wg.Add(count)
for i := 0; i < count; i++ {
go func() {
defer wg.Done()
batch.Add(errors.New(err1))
}()
}
wg.Wait()
assert.NotNil(t, batch.Err())
assert.Equal(t, count, len(batch.errs))
assert.True(t, batch.NotNil())
}
func TestBatchError_Unwrap(t *testing.T) {
t.Run("nil", func(t *testing.T) {
var be BatchError
assert.Nil(t, be.Err())
assert.True(t, errors.Is(be.Err(), nil))
})
t.Run("one error", func(t *testing.T) {
var errFoo = errors.New("foo")
var errBar = errors.New("bar")
var be BatchError
be.Add(errFoo)
assert.True(t, errors.Is(be.Err(), errFoo))
assert.False(t, errors.Is(be.Err(), errBar))
})
t.Run("two errors", func(t *testing.T) {
var errFoo = errors.New("foo")
var errBar = errors.New("bar")
var errBaz = errors.New("baz")
var be BatchError
be.Add(errFoo)
be.Add(errBar)
assert.True(t, errors.Is(be.Err(), errFoo))
assert.True(t, errors.Is(be.Err(), errBar))
assert.False(t, errors.Is(be.Err(), errBaz))
})
}
func TestBatchError_Add(t *testing.T) {
var be BatchError
// Test adding nil errors
be.Add(nil, nil)
assert.False(t, be.NotNil(), "Expected BatchError to be empty after adding nil errors")
// Test adding non-nil errors
err1 := errors.New("error 1")
err2 := errors.New("error 2")
be.Add(err1, err2)
assert.True(t, be.NotNil(), "Expected BatchError to be non-empty after adding errors")
// Test adding a mix of nil and non-nil errors
err3 := errors.New("error 3")
be.Add(nil, err3, nil)
assert.True(t, be.NotNil(), "Expected BatchError to be non-empty after adding a mix of nil and non-nil errors")
}
func TestBatchError_Err(t *testing.T) {
var be BatchError
// Test Err() on empty BatchError
assert.Nil(t, be.Err(), "Expected nil error for empty BatchError")
// Test Err() with multiple errors
err1 := errors.New("error 1")
err2 := errors.New("error 2")
be.Add(err1, err2)
combinedErr := be.Err()
assert.NotNil(t, combinedErr, "Expected nil error for BatchError with multiple errors")
// Check if the combined error contains both error messages
errString := combinedErr.Error()
assert.Truef(t, errors.Is(combinedErr, err1), "Combined error doesn't contain first error: %s", errString)
assert.Truef(t, errors.Is(combinedErr, err2), "Combined error doesn't contain second error: %s", errString)
}
func TestBatchError_NotNil(t *testing.T) {
var be BatchError
// Test NotNil() on empty BatchError
assert.Nil(t, be.Err(), "Expected nil error for empty BatchError")
// Test NotNil() after adding an error
be.Add(errors.New("test error"))
assert.NotNil(t, be.Err(), "Expected non-nil error after adding an error")
}

View File

@@ -1,14 +0,0 @@
package errorx
import "errors"
// In checks if the given err is one of errs.
func In(err error, errs ...error) bool {
for _, each := range errs {
if errors.Is(err, each) {
return true
}
}
return false
}

View File

@@ -1,70 +0,0 @@
package errorx
import (
"errors"
"testing"
)
func TestIn(t *testing.T) {
err1 := errors.New("error 1")
err2 := errors.New("error 2")
err3 := errors.New("error 3")
tests := []struct {
name string
err error
errs []error
want bool
}{
{
name: "Error matches one of the errors in the list",
err: err1,
errs: []error{err1, err2},
want: true,
},
{
name: "Error does not match any errors in the list",
err: err3,
errs: []error{err1, err2},
want: false,
},
{
name: "Empty error list",
err: err1,
errs: []error{},
want: false,
},
{
name: "Nil error with non-nil list",
err: nil,
errs: []error{err1, err2},
want: false,
},
{
name: "Non-nil error with nil in list",
err: err1,
errs: []error{nil, err2},
want: false,
},
{
name: "Error matches nil error in the list",
err: nil,
errs: []error{nil, err2},
want: true,
},
{
name: "Nil error with empty list",
err: nil,
errs: []error{},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := In(tt.err, tt.errs...); got != tt.want {
t.Errorf("In() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -12,7 +12,7 @@ func Wrap(err error, message string) error {
} }
// Wrapf returns an error that wraps err with given format and args. // Wrapf returns an error that wraps err with given format and args.
func Wrapf(err error, format string, args ...any) error { func Wrapf(err error, format string, args ...interface{}) error {
if err == nil { if err == nil {
return nil return nil
} }

View File

@@ -42,7 +42,7 @@ func NewBulkExecutor(execute Execute, opts ...BulkOption) *BulkExecutor {
} }
// Add adds task into be. // Add adds task into be.
func (be *BulkExecutor) Add(task any) error { func (be *BulkExecutor) Add(task interface{}) error {
be.executor.Add(task) be.executor.Add(task)
return nil return nil
} }
@@ -79,22 +79,22 @@ func newBulkOptions() bulkOptions {
} }
type bulkContainer struct { type bulkContainer struct {
tasks []any tasks []interface{}
execute Execute execute Execute
maxTasks int maxTasks int
} }
func (bc *bulkContainer) AddTask(task any) bool { func (bc *bulkContainer) AddTask(task interface{}) bool {
bc.tasks = append(bc.tasks, task) bc.tasks = append(bc.tasks, task)
return len(bc.tasks) >= bc.maxTasks return len(bc.tasks) >= bc.maxTasks
} }
func (bc *bulkContainer) Execute(tasks any) { func (bc *bulkContainer) Execute(tasks interface{}) {
vals := tasks.([]any) vals := tasks.([]interface{})
bc.execute(vals) bc.execute(vals)
} }
func (bc *bulkContainer) RemoveAll() any { func (bc *bulkContainer) RemoveAll() interface{} {
tasks := bc.tasks tasks := bc.tasks
bc.tasks = nil bc.tasks = nil
return tasks return tasks

View File

@@ -12,7 +12,7 @@ func TestBulkExecutor(t *testing.T) {
var values []int var values []int
var lock sync.Mutex var lock sync.Mutex
executor := NewBulkExecutor(func(items []any) { executor := NewBulkExecutor(func(items []interface{}) {
lock.Lock() lock.Lock()
values = append(values, len(items)) values = append(values, len(items))
lock.Unlock() lock.Unlock()
@@ -40,7 +40,7 @@ func TestBulkExecutorFlushInterval(t *testing.T) {
var wait sync.WaitGroup var wait sync.WaitGroup
wait.Add(1) wait.Add(1)
executor := NewBulkExecutor(func(items []any) { executor := NewBulkExecutor(func(items []interface{}) {
assert.Equal(t, size, len(items)) assert.Equal(t, size, len(items))
wait.Done() wait.Done()
}, WithBulkTasks(caches), WithBulkInterval(time.Millisecond*100)) }, WithBulkTasks(caches), WithBulkInterval(time.Millisecond*100))
@@ -53,7 +53,7 @@ func TestBulkExecutorFlushInterval(t *testing.T) {
} }
func TestBulkExecutorEmpty(t *testing.T) { func TestBulkExecutorEmpty(t *testing.T) {
NewBulkExecutor(func(items []any) { NewBulkExecutor(func(items []interface{}) {
assert.Fail(t, "should not called") assert.Fail(t, "should not called")
}, WithBulkTasks(10), WithBulkInterval(time.Millisecond)) }, WithBulkTasks(10), WithBulkInterval(time.Millisecond))
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
@@ -67,7 +67,7 @@ func TestBulkExecutorFlush(t *testing.T) {
var wait sync.WaitGroup var wait sync.WaitGroup
wait.Add(1) wait.Add(1)
be := NewBulkExecutor(func(items []any) { be := NewBulkExecutor(func(items []interface{}) {
assert.Equal(t, tasks, len(items)) assert.Equal(t, tasks, len(items))
wait.Done() wait.Done()
}, WithBulkTasks(caches), WithBulkInterval(time.Minute)) }, WithBulkTasks(caches), WithBulkInterval(time.Minute))
@@ -78,11 +78,11 @@ func TestBulkExecutorFlush(t *testing.T) {
wait.Wait() wait.Wait()
} }
func TestBulkExecutorFlushSlowTasks(t *testing.T) { func TestBuldExecutorFlushSlowTasks(t *testing.T) {
const total = 1500 const total = 1500
lock := new(sync.Mutex) lock := new(sync.Mutex)
result := make([]any, 0, 10000) result := make([]interface{}, 0, 10000)
exec := NewBulkExecutor(func(tasks []any) { exec := NewBulkExecutor(func(tasks []interface{}) {
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
@@ -100,7 +100,7 @@ func TestBulkExecutorFlushSlowTasks(t *testing.T) {
func BenchmarkBulkExecutor(b *testing.B) { func BenchmarkBulkExecutor(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
be := NewBulkExecutor(func(tasks []any) { be := NewBulkExecutor(func(tasks []interface{}) {
time.Sleep(time.Millisecond * time.Duration(len(tasks))) time.Sleep(time.Millisecond * time.Duration(len(tasks)))
}) })
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

View File

@@ -42,7 +42,7 @@ func NewChunkExecutor(execute Execute, opts ...ChunkOption) *ChunkExecutor {
} }
// Add adds task with given chunk size into ce. // Add adds task with given chunk size into ce.
func (ce *ChunkExecutor) Add(task any, size int) error { func (ce *ChunkExecutor) Add(task interface{}, size int) error {
ce.executor.Add(chunk{ ce.executor.Add(chunk{
val: task, val: task,
size: size, size: size,
@@ -82,25 +82,25 @@ func newChunkOptions() chunkOptions {
} }
type chunkContainer struct { type chunkContainer struct {
tasks []any tasks []interface{}
execute Execute execute Execute
size int size int
maxChunkSize int maxChunkSize int
} }
func (bc *chunkContainer) AddTask(task any) bool { func (bc *chunkContainer) AddTask(task interface{}) bool {
ck := task.(chunk) ck := task.(chunk)
bc.tasks = append(bc.tasks, ck.val) bc.tasks = append(bc.tasks, ck.val)
bc.size += ck.size bc.size += ck.size
return bc.size >= bc.maxChunkSize return bc.size >= bc.maxChunkSize
} }
func (bc *chunkContainer) Execute(tasks any) { func (bc *chunkContainer) Execute(tasks interface{}) {
vals := tasks.([]any) vals := tasks.([]interface{})
bc.execute(vals) bc.execute(vals)
} }
func (bc *chunkContainer) RemoveAll() any { func (bc *chunkContainer) RemoveAll() interface{} {
tasks := bc.tasks tasks := bc.tasks
bc.tasks = nil bc.tasks = nil
bc.size = 0 bc.size = 0
@@ -108,6 +108,6 @@ func (bc *chunkContainer) RemoveAll() any {
} }
type chunk struct { type chunk struct {
val any val interface{}
size int size int
} }

View File

@@ -12,7 +12,7 @@ func TestChunkExecutor(t *testing.T) {
var values []int var values []int
var lock sync.Mutex var lock sync.Mutex
executor := NewChunkExecutor(func(items []any) { executor := NewChunkExecutor(func(items []interface{}) {
lock.Lock() lock.Lock()
values = append(values, len(items)) values = append(values, len(items))
lock.Unlock() lock.Unlock()
@@ -40,7 +40,7 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
var wait sync.WaitGroup var wait sync.WaitGroup
wait.Add(1) wait.Add(1)
executor := NewChunkExecutor(func(items []any) { executor := NewChunkExecutor(func(items []interface{}) {
assert.Equal(t, size, len(items)) assert.Equal(t, size, len(items))
wait.Done() wait.Done()
}, WithChunkBytes(caches), WithFlushInterval(time.Millisecond*100)) }, WithChunkBytes(caches), WithFlushInterval(time.Millisecond*100))
@@ -53,7 +53,7 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
} }
func TestChunkExecutorEmpty(t *testing.T) { func TestChunkExecutorEmpty(t *testing.T) {
executor := NewChunkExecutor(func(items []any) { executor := NewChunkExecutor(func(items []interface{}) {
assert.Fail(t, "should not called") assert.Fail(t, "should not called")
}, WithChunkBytes(10), WithFlushInterval(time.Millisecond)) }, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
@@ -68,7 +68,7 @@ func TestChunkExecutorFlush(t *testing.T) {
var wait sync.WaitGroup var wait sync.WaitGroup
wait.Add(1) wait.Add(1)
be := NewChunkExecutor(func(items []any) { be := NewChunkExecutor(func(items []interface{}) {
assert.Equal(t, tasks, len(items)) assert.Equal(t, tasks, len(items))
wait.Done() wait.Done()
}, WithChunkBytes(caches), WithFlushInterval(time.Minute)) }, WithChunkBytes(caches), WithFlushInterval(time.Minute))
@@ -82,7 +82,7 @@ func TestChunkExecutorFlush(t *testing.T) {
func BenchmarkChunkExecutor(b *testing.B) { func BenchmarkChunkExecutor(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
be := NewChunkExecutor(func(tasks []any) { be := NewChunkExecutor(func(tasks []interface{}) {
time.Sleep(time.Millisecond * time.Duration(len(tasks))) time.Sleep(time.Millisecond * time.Duration(len(tasks)))
}) })
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

View File

@@ -21,16 +21,16 @@ type (
TaskContainer interface { TaskContainer interface {
// AddTask adds the task into the container. // AddTask adds the task into the container.
// Returns true if the container needs to be flushed after the addition. // Returns true if the container needs to be flushed after the addition.
AddTask(task any) bool AddTask(task interface{}) bool
// Execute handles the collected tasks by the container when flushing. // Execute handles the collected tasks by the container when flushing.
Execute(tasks any) Execute(tasks interface{})
// RemoveAll removes the contained tasks, and return them. // RemoveAll removes the contained tasks, and return them.
RemoveAll() any RemoveAll() interface{}
} }
// A PeriodicalExecutor is an executor that periodically execute tasks. // A PeriodicalExecutor is an executor that periodically execute tasks.
PeriodicalExecutor struct { PeriodicalExecutor struct {
commander chan any commander chan interface{}
interval time.Duration interval time.Duration
container TaskContainer container TaskContainer
waitGroup sync.WaitGroup waitGroup sync.WaitGroup
@@ -48,7 +48,7 @@ type (
func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *PeriodicalExecutor { func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *PeriodicalExecutor {
executor := &PeriodicalExecutor{ executor := &PeriodicalExecutor{
// buffer 1 to let the caller go quickly // buffer 1 to let the caller go quickly
commander: make(chan any, 1), commander: make(chan interface{}, 1),
interval: interval, interval: interval,
container: container, container: container,
confirmChan: make(chan lang.PlaceholderType), confirmChan: make(chan lang.PlaceholderType),
@@ -64,7 +64,7 @@ func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *Per
} }
// Add adds tasks into pe. // Add adds tasks into pe.
func (pe *PeriodicalExecutor) Add(task any) { func (pe *PeriodicalExecutor) Add(task interface{}) {
if vals, ok := pe.addAndCheck(task); ok { if vals, ok := pe.addAndCheck(task); ok {
pe.commander <- vals pe.commander <- vals
<-pe.confirmChan <-pe.confirmChan
@@ -74,14 +74,14 @@ func (pe *PeriodicalExecutor) Add(task any) {
// Flush forces pe to execute tasks. // Flush forces pe to execute tasks.
func (pe *PeriodicalExecutor) Flush() bool { func (pe *PeriodicalExecutor) Flush() bool {
pe.enterExecution() pe.enterExecution()
return pe.executeTasks(func() any { return pe.executeTasks(func() interface{} {
pe.lock.Lock() pe.lock.Lock()
defer pe.lock.Unlock() defer pe.lock.Unlock()
return pe.container.RemoveAll() return pe.container.RemoveAll()
}()) }())
} }
// Sync lets caller run fn thread-safe with pe, especially for the underlying container. // Sync lets caller to run fn thread-safe with pe, especially for the underlying container.
func (pe *PeriodicalExecutor) Sync(fn func()) { func (pe *PeriodicalExecutor) Sync(fn func()) {
pe.lock.Lock() pe.lock.Lock()
defer pe.lock.Unlock() defer pe.lock.Unlock()
@@ -96,7 +96,7 @@ func (pe *PeriodicalExecutor) Wait() {
}) })
} }
func (pe *PeriodicalExecutor) addAndCheck(task any) (any, bool) { func (pe *PeriodicalExecutor) addAndCheck(task interface{}) (interface{}, bool) {
pe.lock.Lock() pe.lock.Lock()
defer func() { defer func() {
if !pe.guarded { if !pe.guarded {
@@ -116,7 +116,7 @@ func (pe *PeriodicalExecutor) addAndCheck(task any) (any, bool) {
} }
func (pe *PeriodicalExecutor) backgroundFlush() { func (pe *PeriodicalExecutor) backgroundFlush() {
go func() { threading.GoSafe(func() {
// flush before quit goroutine to avoid missing tasks // flush before quit goroutine to avoid missing tasks
defer pe.Flush() defer pe.Flush()
@@ -144,7 +144,7 @@ func (pe *PeriodicalExecutor) backgroundFlush() {
} }
} }
} }
}() })
} }
func (pe *PeriodicalExecutor) doneExecution() { func (pe *PeriodicalExecutor) doneExecution() {
@@ -157,20 +157,18 @@ func (pe *PeriodicalExecutor) enterExecution() {
}) })
} }
func (pe *PeriodicalExecutor) executeTasks(tasks any) bool { func (pe *PeriodicalExecutor) executeTasks(tasks interface{}) bool {
defer pe.doneExecution() defer pe.doneExecution()
ok := pe.hasTasks(tasks) ok := pe.hasTasks(tasks)
if ok { if ok {
threading.RunSafe(func() { pe.container.Execute(tasks)
pe.container.Execute(tasks)
})
} }
return ok return ok
} }
func (pe *PeriodicalExecutor) hasTasks(tasks any) bool { func (pe *PeriodicalExecutor) hasTasks(tasks interface{}) bool {
if tasks == nil { if tasks == nil {
return false return false
} }

View File

@@ -17,22 +17,22 @@ const threshold = 10
type container struct { type container struct {
interval time.Duration interval time.Duration
tasks []int tasks []int
execute func(tasks any) execute func(tasks interface{})
} }
func newContainer(interval time.Duration, execute func(tasks any)) *container { func newContainer(interval time.Duration, execute func(tasks interface{})) *container {
return &container{ return &container{
interval: interval, interval: interval,
execute: execute, execute: execute,
} }
} }
func (c *container) AddTask(task any) bool { func (c *container) AddTask(task interface{}) bool {
c.tasks = append(c.tasks, task.(int)) c.tasks = append(c.tasks, task.(int))
return len(c.tasks) > threshold return len(c.tasks) > threshold
} }
func (c *container) Execute(tasks any) { func (c *container) Execute(tasks interface{}) {
if c.execute != nil { if c.execute != nil {
c.execute(tasks) c.execute(tasks)
} else { } else {
@@ -40,7 +40,7 @@ func (c *container) Execute(tasks any) {
} }
} }
func (c *container) RemoveAll() any { func (c *container) RemoveAll() interface{} {
tasks := c.tasks tasks := c.tasks
c.tasks = nil c.tasks = nil
return tasks return tasks
@@ -76,7 +76,7 @@ func TestPeriodicalExecutor_Bulk(t *testing.T) {
var vals []int var vals []int
// avoid data race // avoid data race
var lock sync.Mutex var lock sync.Mutex
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) { exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks interface{}) {
t := tasks.([]int) t := tasks.([]int)
for _, each := range t { for _, each := range t {
lock.Lock() lock.Lock()
@@ -108,83 +108,25 @@ func TestPeriodicalExecutor_Bulk(t *testing.T) {
lock.Unlock() lock.Unlock()
} }
func TestPeriodicalExecutor_Panic(t *testing.T) {
// avoid data race
var lock sync.Mutex
ticker := timex.NewFakeTicker()
var (
executedTasks []int
expected []int
)
executor := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
tt := tasks.([]int)
lock.Lock()
executedTasks = append(executedTasks, tt...)
lock.Unlock()
if tt[0] == 0 {
panic("test")
}
}))
executor.newTicker = func(duration time.Duration) timex.Ticker {
return ticker
}
for i := 0; i < 30; i++ {
executor.Add(i)
expected = append(expected, i)
}
ticker.Tick()
ticker.Tick()
time.Sleep(time.Millisecond)
lock.Lock()
assert.Equal(t, expected, executedTasks)
lock.Unlock()
}
func TestPeriodicalExecutor_FlushPanic(t *testing.T) {
var (
executedTasks []int
expected []int
lock sync.Mutex
)
executor := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, func(tasks any) {
tt := tasks.([]int)
lock.Lock()
executedTasks = append(executedTasks, tt...)
lock.Unlock()
if tt[0] == 0 {
panic("flush panic")
}
}))
for i := 0; i < 8; i++ {
executor.Add(i)
expected = append(expected, i)
}
executor.Flush()
lock.Lock()
assert.Equal(t, expected, executedTasks)
lock.Unlock()
}
func TestPeriodicalExecutor_Wait(t *testing.T) { func TestPeriodicalExecutor_Wait(t *testing.T) {
var lock sync.Mutex var lock sync.Mutex
executor := NewBulkExecutor(func(tasks []any) { executer := NewBulkExecutor(func(tasks []interface{}) {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
}, WithBulkTasks(1), WithBulkInterval(time.Second)) }, WithBulkTasks(1), WithBulkInterval(time.Second))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
executor.Add(1) executer.Add(1)
} }
executor.Flush() executer.Flush()
executor.Wait() executer.Wait()
} }
func TestPeriodicalExecutor_WaitFast(t *testing.T) { func TestPeriodicalExecutor_WaitFast(t *testing.T) {
const total = 3 const total = 3
var cnt int var cnt int
var lock sync.Mutex var lock sync.Mutex
executor := NewBulkExecutor(func(tasks []any) { executer := NewBulkExecutor(func(tasks []interface{}) {
defer func() { defer func() {
cnt++ cnt++
}() }()
@@ -193,15 +135,15 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
}, WithBulkTasks(1), WithBulkInterval(10*time.Millisecond)) }, WithBulkTasks(1), WithBulkInterval(10*time.Millisecond))
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
executor.Add(2) executer.Add(2)
} }
executor.Flush() executer.Flush()
executor.Wait() executer.Wait()
assert.Equal(t, total, cnt) assert.Equal(t, total, cnt)
} }
func TestPeriodicalExecutor_Deadlock(t *testing.T) { func TestPeriodicalExecutor_Deadlock(t *testing.T) {
executor := NewBulkExecutor(func(tasks []any) { executor := NewBulkExecutor(func(tasks []interface{}) {
}, WithBulkTasks(1), WithBulkInterval(time.Millisecond)) }, WithBulkTasks(1), WithBulkInterval(time.Millisecond))
for i := 0; i < 1e5; i++ { for i := 0; i < 1e5; i++ {
executor.Add(1) executor.Add(1)
@@ -209,7 +151,13 @@ func TestPeriodicalExecutor_Deadlock(t *testing.T) {
} }
func TestPeriodicalExecutor_hasTasks(t *testing.T) { func TestPeriodicalExecutor_hasTasks(t *testing.T) {
ticker := timex.NewFakeTicker()
defer ticker.Stop()
exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil)) exec := NewPeriodicalExecutor(time.Millisecond, newContainer(time.Millisecond, nil))
exec.newTicker = func(d time.Duration) timex.Ticker {
return ticker
}
assert.False(t, exec.hasTasks(nil)) assert.False(t, exec.hasTasks(nil))
assert.True(t, exec.hasTasks(1)) assert.True(t, exec.hasTasks(1))
} }

View File

@@ -5,4 +5,4 @@ import "time"
const defaultFlushInterval = time.Second const defaultFlushInterval = time.Second
// Execute defines the method to execute tasks. // Execute defines the method to execute tasks.
type Execute func(tasks []any) type Execute func(tasks []interface{})

View File

@@ -35,7 +35,6 @@ func firstLine(file *os.File) (string, error) {
for { for {
buf := make([]byte, bufSize) buf := make([]byte, bufSize)
n, err := file.ReadAt(buf, offset) n, err := file.ReadAt(buf, offset)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return "", err return "", err
} }
@@ -46,10 +45,6 @@ func firstLine(file *os.File) (string, error) {
} }
} }
if err == io.EOF {
return string(append(first, buf[:n]...)), nil
}
first = append(first, buf[:n]...) first = append(first, buf[:n]...)
offset += bufSize offset += bufSize
} }
@@ -62,42 +57,30 @@ func lastLine(filename string, file *os.File) (string, error) {
} }
var last []byte var last []byte
bufLen := int64(bufSize)
offset := info.Size() offset := info.Size()
for {
for offset > 0 { offset -= bufSize
if offset < bufLen { if offset < 0 {
bufLen = offset
offset = 0 offset = 0
} else {
offset -= bufLen
} }
buf := make([]byte, bufSize)
buf := make([]byte, bufLen)
n, err := file.ReadAt(buf, offset) n, err := file.ReadAt(buf, offset)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return "", err return "", err
} }
if n == 0 {
break
}
if buf[n-1] == '\n' { if buf[n-1] == '\n' {
buf = buf[:n-1] buf = buf[:n-1]
n-- n--
} else { } else {
buf = buf[:n] buf = buf[:n]
} }
for n--; n >= 0; n-- {
for i := n - 1; i >= 0; i-- { if buf[n] == '\n' {
if buf[i] == '\n' { return string(append(buf[n+1:], last...)), nil
return string(append(buf[i+1:], last...)), nil
} }
} }
last = append(buf, last...) last = append(buf, last...)
} }
return string(last), nil
} }

View File

@@ -52,7 +52,6 @@ last line`
second line second line
last line last line
` `
emptyContent = ``
) )
func TestFirstLine(t *testing.T) { func TestFirstLine(t *testing.T) {
@@ -75,31 +74,6 @@ func TestFirstLineShort(t *testing.T) {
assert.Equal(t, "first line", val) assert.Equal(t, "first line", val)
} }
func TestFirstLineError(t *testing.T) {
_, err := FirstLine("/tmp/does-not-exist")
assert.Error(t, err)
}
func TestFirstLineEmptyFile(t *testing.T) {
filename, err := fs.TempFilenameWithText(emptyContent)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, "", val)
}
func TestFirstLineWithoutNewline(t *testing.T) {
filename, err := fs.TempFilenameWithText(longLine)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, longLine, val)
}
func TestLastLine(t *testing.T) { func TestLastLine(t *testing.T) {
filename, err := fs.TempFilenameWithText(text) filename, err := fs.TempFilenameWithText(text)
assert.Nil(t, err) assert.Nil(t, err)
@@ -120,16 +94,6 @@ func TestLastLineWithLastNewline(t *testing.T) {
assert.Equal(t, longLine, val) assert.Equal(t, longLine, val)
} }
func TestLastLineWithoutLastNewline(t *testing.T) {
filename, err := fs.TempFilenameWithText(longLine)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, longLine, val)
}
func TestLastLineShort(t *testing.T) { func TestLastLineShort(t *testing.T) {
filename, err := fs.TempFilenameWithText(shortText) filename, err := fs.TempFilenameWithText(shortText)
assert.Nil(t, err) assert.Nil(t, err)
@@ -149,72 +113,3 @@ func TestLastLineWithLastNewlineShort(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "last line", val) assert.Equal(t, "last line", val)
} }
func TestLastLineError(t *testing.T) {
_, err := LastLine("/tmp/does-not-exist")
assert.Error(t, err)
}
func TestLastLineEmptyFile(t *testing.T) {
filename, err := fs.TempFilenameWithText(emptyContent)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, "", val)
}
func TestFirstLineExactlyBufSize(t *testing.T) {
content := make([]byte, bufSize)
for i := range content {
content[i] = 'a'
}
content[bufSize-1] = '\n' // Ensure there is a newline at the edge
filename, err := fs.TempFilenameWithText(string(content))
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, string(content[:bufSize-1]), val)
}
func TestLastLineExactlyBufSize(t *testing.T) {
content := make([]byte, bufSize)
for i := range content {
content[i] = 'a'
}
content[bufSize-1] = '\n' // Ensure there is a newline at the edge
filename, err := fs.TempFilenameWithText(string(content))
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, string(content[:bufSize-1]), val)
}
func TestFirstLineLargeFile(t *testing.T) {
content := text + text + text + "\n" + "extra"
filename, err := fs.TempFilenameWithText(content)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := FirstLine(filename)
assert.Nil(t, err)
assert.Equal(t, "first line", val)
}
func TestLastLineLargeFile(t *testing.T) {
content := text + text + text + "\n" + "extra"
filename, err := fs.TempFilenameWithText(content)
assert.Nil(t, err)
defer os.Remove(filename)
val, err := LastLine(filename)
assert.Nil(t, err)
assert.Equal(t, "extra", val)
}

View File

@@ -5,7 +5,7 @@ import "gopkg.in/cheggaaa/pb.v1"
type ( type (
// A Scanner is used to read lines. // A Scanner is used to read lines.
Scanner interface { Scanner interface {
// Scan checks if it has remaining to read. // Scan checks if has remaining to read.
Scan() bool Scan() bool
// Text returns next line. // Text returns next line.
Text() string Text() string

View File

@@ -1,4 +1,5 @@
//go:build windows //go:build windows
// +build windows
package fs package fs

View File

@@ -1,4 +1,5 @@
//go:build linux || darwin || freebsd //go:build linux || darwin
// +build linux darwin
package fs package fs

View File

@@ -11,29 +11,29 @@ import (
// The file is kept as open, the caller should close the file handle, // The file is kept as open, the caller should close the file handle,
// and remove the file by name. // and remove the file by name.
func TempFileWithText(text string) (*os.File, error) { func TempFileWithText(text string) (*os.File, error) {
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))) tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil { if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
return nil, err return nil, err
} }
return tmpFile, nil return tmpfile, nil
} }
// TempFilenameWithText creates the file with the given content, // TempFilenameWithText creates the file with the given content,
// and returns the filename (full path). // and returns the filename (full path).
// The caller should remove the file after use. // The caller should remove the file after use.
func TempFilenameWithText(text string) (string, error) { func TempFilenameWithText(text string) (string, error) {
tmpFile, err := TempFileWithText(text) tmpfile, err := TempFileWithText(text)
if err != nil { if err != nil {
return "", err return "", err
} }
filename := tmpFile.Name() filename := tmpfile.Name()
if err = tmpFile.Close(); err != nil { if err = tmpfile.Close(); err != nil {
return "", err return "", err
} }

View File

@@ -1,9 +1,6 @@
package fx package fx
import ( import "github.com/zeromicro/go-zero/core/threading"
"github.com/zeromicro/go-zero/core/errorx"
"github.com/zeromicro/go-zero/core/threading"
)
// Parallel runs fns parallelly and waits for done. // Parallel runs fns parallelly and waits for done.
func Parallel(fns ...func()) { func Parallel(fns ...func()) {
@@ -13,20 +10,3 @@ func Parallel(fns ...func()) {
} }
group.Wait() group.Wait()
} }
func ParallelErr(fns ...func() error) error {
var be errorx.BatchError
group := threading.NewRoutineGroup()
for _, fn := range fns {
f := fn
group.RunSafe(func() {
if err := f(); err != nil {
be.Add(err)
}
})
}
group.Wait()
return be.Err()
}

View File

@@ -1,7 +1,6 @@
package fx package fx
import ( import (
"errors"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@@ -23,54 +22,3 @@ func TestParallel(t *testing.T) {
}) })
assert.Equal(t, int32(6), count) assert.Equal(t, int32(6), count)
} }
func TestParallelErr(t *testing.T) {
var count int32
err := ParallelErr(
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 1)
return errors.New("failed to exec #1")
},
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 2)
return errors.New("failed to exec #2")
},
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 3)
return nil
},
)
assert.Equal(t, int32(6), count)
assert.Error(t, err)
assert.ErrorContains(t, err, "failed to exec #1", "failed to exec #2")
}
func TestParallelErrErrorNil(t *testing.T) {
var count int32
err := ParallelErr(
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 1)
return nil
},
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 2)
return nil
},
func() error {
time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, 3)
return nil
},
)
assert.Equal(t, int32(6), count)
assert.NoError(t, err)
}

View File

@@ -1,12 +1,6 @@
package fx package fx
import ( import "github.com/zeromicro/go-zero/core/errorx"
"context"
"errors"
"time"
"github.com/zeromicro/go-zero/core/errorx"
)
const defaultRetryTimes = 3 const defaultRetryTimes = 3
@@ -15,110 +9,36 @@ type (
RetryOption func(*retryOptions) RetryOption func(*retryOptions)
retryOptions struct { retryOptions struct {
times int times int
interval time.Duration
timeout time.Duration
ignoreErrors []error
} }
) )
// DoWithRetry runs fn, and retries if failed. Default to retry 3 times. // DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
// Note that if the fn function accesses global variables outside the function
// and performs modification operations, it is best to lock them,
// otherwise there may be data race issues
func DoWithRetry(fn func() error, opts ...RetryOption) error { func DoWithRetry(fn func() error, opts ...RetryOption) error {
return retry(context.Background(), func(errChan chan error, retryCount int) {
errChan <- fn()
}, opts...)
}
// DoWithRetryCtx runs fn, and retries if failed. Default to retry 3 times.
// fn retryCount indicates the current number of retries, starting from 0
// Note that if the fn function accesses global variables outside the function
// and performs modification operations, it is best to lock them,
// otherwise there may be data race issues
func DoWithRetryCtx(ctx context.Context, fn func(ctx context.Context, retryCount int) error,
opts ...RetryOption) error {
return retry(ctx, func(errChan chan error, retryCount int) {
errChan <- fn(ctx, retryCount)
}, opts...)
}
func retry(ctx context.Context, fn func(errChan chan error, retryCount int), opts ...RetryOption) error {
options := newRetryOptions() options := newRetryOptions()
for _, opt := range opts { for _, opt := range opts {
opt(options) opt(options)
} }
var berr errorx.BatchError var berr errorx.BatchError
var cancelFunc context.CancelFunc
if options.timeout > 0 {
ctx, cancelFunc = context.WithTimeout(ctx, options.timeout)
defer cancelFunc()
}
errChan := make(chan error, 1)
for i := 0; i < options.times; i++ { for i := 0; i < options.times; i++ {
go fn(errChan, i) if err := fn(); err != nil {
berr.Add(err)
select { } else {
case err := <-errChan: return nil
if err != nil {
for _, ignoreErr := range options.ignoreErrors {
if errors.Is(err, ignoreErr) {
return nil
}
}
berr.Add(err)
} else {
return nil
}
case <-ctx.Done():
berr.Add(ctx.Err())
return berr.Err()
}
if options.interval > 0 {
select {
case <-ctx.Done():
berr.Add(ctx.Err())
return berr.Err()
case <-time.After(options.interval):
}
} }
} }
return berr.Err() return berr.Err()
} }
// WithIgnoreErrors Ignore the specified errors // WithRetry customize a DoWithRetry call with given retry times.
func WithIgnoreErrors(ignoreErrors []error) RetryOption {
return func(options *retryOptions) {
options.ignoreErrors = ignoreErrors
}
}
// WithInterval customizes a DoWithRetry call with given interval.
func WithInterval(interval time.Duration) RetryOption {
return func(options *retryOptions) {
options.interval = interval
}
}
// WithRetry customizes a DoWithRetry call with given retry times.
func WithRetry(times int) RetryOption { func WithRetry(times int) RetryOption {
return func(options *retryOptions) { return func(options *retryOptions) {
options.times = times options.times = times
} }
} }
// WithTimeout customizes a DoWithRetry call with given timeout.
func WithTimeout(timeout time.Duration) RetryOption {
return func(options *retryOptions) {
options.timeout = timeout
}
}
func newRetryOptions() *retryOptions { func newRetryOptions() *retryOptions {
return &retryOptions{ return &retryOptions{
times: defaultRetryTimes, times: defaultRetryTimes,

View File

@@ -1,10 +1,8 @@
package fx package fx
import ( import (
"context"
"errors" "errors"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -14,153 +12,31 @@ func TestRetry(t *testing.T) {
return errors.New("any") return errors.New("any")
})) }))
times1 := 0 var times int
assert.Nil(t, DoWithRetry(func() error { assert.Nil(t, DoWithRetry(func() error {
times1++ times++
if times1 == defaultRetryTimes { if times == defaultRetryTimes {
return nil return nil
} }
return errors.New("any") return errors.New("any")
})) }))
times2 := 0 times = 0
assert.NotNil(t, DoWithRetry(func() error { assert.NotNil(t, DoWithRetry(func() error {
times2++ times++
if times2 == defaultRetryTimes+1 { if times == defaultRetryTimes+1 {
return nil return nil
} }
return errors.New("any") return errors.New("any")
})) }))
total := 2 * defaultRetryTimes total := 2 * defaultRetryTimes
times3 := 0 times = 0
assert.Nil(t, DoWithRetry(func() error { assert.Nil(t, DoWithRetry(func() error {
times3++ times++
if times3 == total { if times == total {
return nil return nil
} }
return errors.New("any") return errors.New("any")
}, WithRetry(total))) }, WithRetry(total)))
} }
func TestRetryWithTimeout(t *testing.T) {
assert.Nil(t, DoWithRetry(func() error {
return nil
}, WithTimeout(time.Millisecond*500)))
times1 := 0
assert.Nil(t, DoWithRetry(func() error {
times1++
if times1 == 1 {
return errors.New("any ")
}
time.Sleep(time.Millisecond * 150)
return nil
}, WithTimeout(time.Millisecond*250)))
total := defaultRetryTimes
times2 := 0
assert.Nil(t, DoWithRetry(func() error {
times2++
if times2 == total {
return nil
}
time.Sleep(time.Millisecond * 50)
return errors.New("any")
}, WithTimeout(time.Millisecond*50*(time.Duration(total)+2))))
assert.NotNil(t, DoWithRetry(func() error {
return errors.New("any")
}, WithTimeout(time.Millisecond*250)))
}
func TestRetryWithInterval(t *testing.T) {
times1 := 0
assert.NotNil(t, DoWithRetry(func() error {
times1++
if times1 == 1 {
return errors.New("any")
}
time.Sleep(time.Millisecond * 150)
return nil
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
times2 := 0
assert.NotNil(t, DoWithRetry(func() error {
times2++
if times2 == 2 {
return nil
}
time.Sleep(time.Millisecond * 150)
return errors.New("any ")
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
}
func TestRetryWithWithIgnoreErrors(t *testing.T) {
ignoreErr1 := errors.New("ignore error1")
ignoreErr2 := errors.New("ignore error2")
ignoreErrs := []error{ignoreErr1, ignoreErr2}
assert.Nil(t, DoWithRetry(func() error {
return ignoreErr1
}, WithIgnoreErrors(ignoreErrs)))
assert.Nil(t, DoWithRetry(func() error {
return ignoreErr2
}, WithIgnoreErrors(ignoreErrs)))
assert.NotNil(t, DoWithRetry(func() error {
return errors.New("any")
}))
}
func TestRetryCtx(t *testing.T) {
t.Run("with timeout", func(t *testing.T) {
assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
if retryCount == 0 {
return errors.New("any")
}
time.Sleep(time.Millisecond * 150)
return nil
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
if retryCount == 1 {
return nil
}
time.Sleep(time.Millisecond * 150)
return errors.New("any ")
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
})
t.Run("with deadline exceeded", func(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*250))
defer cancel()
var times int
assert.Error(t, DoWithRetryCtx(ctx, func(ctx context.Context, retryCount int) error {
times++
time.Sleep(time.Millisecond * 150)
return errors.New("any")
}, WithInterval(time.Millisecond*150)))
assert.Equal(t, 1, times)
})
t.Run("with deadline not exceeded", func(t *testing.T) {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond*250))
defer cancel()
var times int
assert.NoError(t, DoWithRetryCtx(ctx, func(ctx context.Context, retryCount int) error {
times++
if times == defaultRetryTimes {
return nil
}
time.Sleep(time.Millisecond * 50)
return errors.New("any")
}))
assert.Equal(t, defaultRetryTimes, times)
})
}

View File

@@ -21,31 +21,31 @@ type (
} }
// FilterFunc defines the method to filter a Stream. // FilterFunc defines the method to filter a Stream.
FilterFunc func(item any) bool FilterFunc func(item interface{}) bool
// ForAllFunc defines the method to handle all elements in a Stream. // ForAllFunc defines the method to handle all elements in a Stream.
ForAllFunc func(pipe <-chan any) ForAllFunc func(pipe <-chan interface{})
// ForEachFunc defines the method to handle each element in a Stream. // ForEachFunc defines the method to handle each element in a Stream.
ForEachFunc func(item any) ForEachFunc func(item interface{})
// GenerateFunc defines the method to send elements into a Stream. // GenerateFunc defines the method to send elements into a Stream.
GenerateFunc func(source chan<- any) GenerateFunc func(source chan<- interface{})
// KeyFunc defines the method to generate keys for the elements in a Stream. // KeyFunc defines the method to generate keys for the elements in a Stream.
KeyFunc func(item any) any KeyFunc func(item interface{}) interface{}
// LessFunc defines the method to compare the elements in a Stream. // LessFunc defines the method to compare the elements in a Stream.
LessFunc func(a, b any) bool LessFunc func(a, b interface{}) bool
// MapFunc defines the method to map each element to another object in a Stream. // MapFunc defines the method to map each element to another object in a Stream.
MapFunc func(item any) any MapFunc func(item interface{}) interface{}
// Option defines the method to customize a Stream. // Option defines the method to customize a Stream.
Option func(opts *rxOptions) Option func(opts *rxOptions)
// ParallelFunc defines the method to handle elements parallelly. // ParallelFunc defines the method to handle elements parallelly.
ParallelFunc func(item any) ParallelFunc func(item interface{})
// ReduceFunc defines the method to reduce all the elements in a Stream. // ReduceFunc defines the method to reduce all the elements in a Stream.
ReduceFunc func(pipe <-chan any) (any, error) ReduceFunc func(pipe <-chan interface{}) (interface{}, error)
// WalkFunc defines the method to walk through all the elements in a Stream. // WalkFunc defines the method to walk through all the elements in a Stream.
WalkFunc func(item any, pipe chan<- any) WalkFunc func(item interface{}, pipe chan<- interface{})
// A Stream is a stream that can be used to do stream processing. // A Stream is a stream that can be used to do stream processing.
Stream struct { Stream struct {
source <-chan any source <-chan interface{}
} }
) )
@@ -56,7 +56,7 @@ func Concat(s Stream, others ...Stream) Stream {
// From constructs a Stream from the given GenerateFunc. // From constructs a Stream from the given GenerateFunc.
func From(generate GenerateFunc) Stream { func From(generate GenerateFunc) Stream {
source := make(chan any) source := make(chan interface{})
threading.GoSafe(func() { threading.GoSafe(func() {
defer close(source) defer close(source)
@@ -67,8 +67,8 @@ func From(generate GenerateFunc) Stream {
} }
// Just converts the given arbitrary items to a Stream. // Just converts the given arbitrary items to a Stream.
func Just(items ...any) Stream { func Just(items ...interface{}) Stream {
source := make(chan any, len(items)) source := make(chan interface{}, len(items))
for _, item := range items { for _, item := range items {
source <- item source <- item
} }
@@ -78,16 +78,16 @@ func Just(items ...any) Stream {
} }
// Range converts the given channel to a Stream. // Range converts the given channel to a Stream.
func Range(source <-chan any) Stream { func Range(source <-chan interface{}) Stream {
return Stream{ return Stream{
source: source, source: source,
} }
} }
// AllMatch returns whether all elements of this stream match the provided predicate. // AllMach returns whether all elements of this stream match the provided predicate.
// May not evaluate the predicate on all elements if not necessary for determining the result. // May not evaluate the predicate on all elements if not necessary for determining the result.
// If the stream is empty then true is returned and the predicate is not evaluated. // If the stream is empty then true is returned and the predicate is not evaluated.
func (s Stream) AllMatch(predicate func(item any) bool) bool { func (s Stream) AllMach(predicate func(item interface{}) bool) bool {
for item := range s.source { for item := range s.source {
if !predicate(item) { if !predicate(item) {
// make sure the former goroutine not block, and current func returns fast. // make sure the former goroutine not block, and current func returns fast.
@@ -99,10 +99,10 @@ func (s Stream) AllMatch(predicate func(item any) bool) bool {
return true return true
} }
// AnyMatch returns whether any elements of this stream match the provided predicate. // AnyMach returns whether any elements of this stream match the provided predicate.
// May not evaluate the predicate on all elements if not necessary for determining the result. // May not evaluate the predicate on all elements if not necessary for determining the result.
// If the stream is empty then false is returned and the predicate is not evaluated. // If the stream is empty then false is returned and the predicate is not evaluated.
func (s Stream) AnyMatch(predicate func(item any) bool) bool { func (s Stream) AnyMach(predicate func(item interface{}) bool) bool {
for item := range s.source { for item := range s.source {
if predicate(item) { if predicate(item) {
// make sure the former goroutine not block, and current func returns fast. // make sure the former goroutine not block, and current func returns fast.
@@ -121,7 +121,7 @@ func (s Stream) Buffer(n int) Stream {
n = 0 n = 0
} }
source := make(chan any, n) source := make(chan interface{}, n)
go func() { go func() {
for item := range s.source { for item := range s.source {
source <- item source <- item
@@ -134,7 +134,7 @@ func (s Stream) Buffer(n int) Stream {
// Concat returns a Stream that concatenated other streams // Concat returns a Stream that concatenated other streams
func (s Stream) Concat(others ...Stream) Stream { func (s Stream) Concat(others ...Stream) Stream {
source := make(chan any) source := make(chan interface{})
go func() { go func() {
group := threading.NewRoutineGroup() group := threading.NewRoutineGroup()
@@ -170,12 +170,12 @@ func (s Stream) Count() (count int) {
// Distinct removes the duplicated items base on the given KeyFunc. // Distinct removes the duplicated items base on the given KeyFunc.
func (s Stream) Distinct(fn KeyFunc) Stream { func (s Stream) Distinct(fn KeyFunc) Stream {
source := make(chan any) source := make(chan interface{})
threading.GoSafe(func() { threading.GoSafe(func() {
defer close(source) defer close(source)
keys := make(map[any]lang.PlaceholderType) keys := make(map[interface{}]lang.PlaceholderType)
for item := range s.source { for item := range s.source {
key := fn(item) key := fn(item)
if _, ok := keys[key]; !ok { if _, ok := keys[key]; !ok {
@@ -195,7 +195,7 @@ func (s Stream) Done() {
// Filter filters the items by the given FilterFunc. // Filter filters the items by the given FilterFunc.
func (s Stream) Filter(fn FilterFunc, opts ...Option) Stream { func (s Stream) Filter(fn FilterFunc, opts ...Option) Stream {
return s.Walk(func(item any, pipe chan<- any) { return s.Walk(func(item interface{}, pipe chan<- interface{}) {
if fn(item) { if fn(item) {
pipe <- item pipe <- item
} }
@@ -203,7 +203,7 @@ func (s Stream) Filter(fn FilterFunc, opts ...Option) Stream {
} }
// First returns the first item, nil if no items. // First returns the first item, nil if no items.
func (s Stream) First() any { func (s Stream) First() interface{} {
for item := range s.source { for item := range s.source {
// make sure the former goroutine not block, and current func returns fast. // make sure the former goroutine not block, and current func returns fast.
go drain(s.source) go drain(s.source)
@@ -229,13 +229,13 @@ func (s Stream) ForEach(fn ForEachFunc) {
// Group groups the elements into different groups based on their keys. // Group groups the elements into different groups based on their keys.
func (s Stream) Group(fn KeyFunc) Stream { func (s Stream) Group(fn KeyFunc) Stream {
groups := make(map[any][]any) groups := make(map[interface{}][]interface{})
for item := range s.source { for item := range s.source {
key := fn(item) key := fn(item)
groups[key] = append(groups[key], item) groups[key] = append(groups[key], item)
} }
source := make(chan any) source := make(chan interface{})
go func() { go func() {
for _, group := range groups { for _, group := range groups {
source <- group source <- group
@@ -252,7 +252,7 @@ func (s Stream) Head(n int64) Stream {
panic("n must be greater than 0") panic("n must be greater than 0")
} }
source := make(chan any) source := make(chan interface{})
go func() { go func() {
for item := range s.source { for item := range s.source {
@@ -279,7 +279,7 @@ func (s Stream) Head(n int64) Stream {
} }
// Last returns the last item, or nil if no items. // Last returns the last item, or nil if no items.
func (s Stream) Last() (item any) { func (s Stream) Last() (item interface{}) {
for item = range s.source { for item = range s.source {
} }
return return
@@ -287,53 +287,29 @@ func (s Stream) Last() (item any) {
// Map converts each item to another corresponding item, which means it's a 1:1 model. // Map converts each item to another corresponding item, which means it's a 1:1 model.
func (s Stream) Map(fn MapFunc, opts ...Option) Stream { func (s Stream) Map(fn MapFunc, opts ...Option) Stream {
return s.Walk(func(item any, pipe chan<- any) { return s.Walk(func(item interface{}, pipe chan<- interface{}) {
pipe <- fn(item) pipe <- fn(item)
}, opts...) }, opts...)
} }
// Max returns the maximum item from the underlying source.
func (s Stream) Max(less LessFunc) any {
var max any
for item := range s.source {
if max == nil || less(max, item) {
max = item
}
}
return max
}
// Merge merges all the items into a slice and generates a new stream. // Merge merges all the items into a slice and generates a new stream.
func (s Stream) Merge() Stream { func (s Stream) Merge() Stream {
var items []any var items []interface{}
for item := range s.source { for item := range s.source {
items = append(items, item) items = append(items, item)
} }
source := make(chan any, 1) source := make(chan interface{}, 1)
source <- items source <- items
close(source) close(source)
return Range(source) return Range(source)
} }
// Min returns the minimum item from the underlying source.
func (s Stream) Min(less LessFunc) any {
var min any
for item := range s.source {
if min == nil || less(item, min) {
min = item
}
}
return min
}
// NoneMatch returns whether all elements of this stream don't match the provided predicate. // NoneMatch returns whether all elements of this stream don't match the provided predicate.
// May not evaluate the predicate on all elements if not necessary for determining the result. // May not evaluate the predicate on all elements if not necessary for determining the result.
// If the stream is empty then true is returned and the predicate is not evaluated. // If the stream is empty then true is returned and the predicate is not evaluated.
func (s Stream) NoneMatch(predicate func(item any) bool) bool { func (s Stream) NoneMatch(predicate func(item interface{}) bool) bool {
for item := range s.source { for item := range s.source {
if predicate(item) { if predicate(item) {
// make sure the former goroutine not block, and current func returns fast. // make sure the former goroutine not block, and current func returns fast.
@@ -347,19 +323,19 @@ func (s Stream) NoneMatch(predicate func(item any) bool) bool {
// Parallel applies the given ParallelFunc to each item concurrently with given number of workers. // Parallel applies the given ParallelFunc to each item concurrently with given number of workers.
func (s Stream) Parallel(fn ParallelFunc, opts ...Option) { func (s Stream) Parallel(fn ParallelFunc, opts ...Option) {
s.Walk(func(item any, pipe chan<- any) { s.Walk(func(item interface{}, pipe chan<- interface{}) {
fn(item) fn(item)
}, opts...).Done() }, opts...).Done()
} }
// Reduce is a utility method to let the caller deal with the underlying channel. // Reduce is an utility method to let the caller deal with the underlying channel.
func (s Stream) Reduce(fn ReduceFunc) (any, error) { func (s Stream) Reduce(fn ReduceFunc) (interface{}, error) {
return fn(s.source) return fn(s.source)
} }
// Reverse reverses the elements in the stream. // Reverse reverses the elements in the stream.
func (s Stream) Reverse() Stream { func (s Stream) Reverse() Stream {
var items []any var items []interface{}
for item := range s.source { for item := range s.source {
items = append(items, item) items = append(items, item)
} }
@@ -381,7 +357,7 @@ func (s Stream) Skip(n int64) Stream {
return s return s
} }
source := make(chan any) source := make(chan interface{})
go func() { go func() {
for item := range s.source { for item := range s.source {
@@ -400,7 +376,7 @@ func (s Stream) Skip(n int64) Stream {
// Sort sorts the items from the underlying source. // Sort sorts the items from the underlying source.
func (s Stream) Sort(less LessFunc) Stream { func (s Stream) Sort(less LessFunc) Stream {
var items []any var items []interface{}
for item := range s.source { for item := range s.source {
items = append(items, item) items = append(items, item)
} }
@@ -418,9 +394,9 @@ func (s Stream) Split(n int) Stream {
panic("n should be greater than 0") panic("n should be greater than 0")
} }
source := make(chan any) source := make(chan interface{})
go func() { go func() {
var chunk []any var chunk []interface{}
for item := range s.source { for item := range s.source {
chunk = append(chunk, item) chunk = append(chunk, item)
if len(chunk) == n { if len(chunk) == n {
@@ -443,7 +419,7 @@ func (s Stream) Tail(n int64) Stream {
panic("n should be greater than 0") panic("n should be greater than 0")
} }
source := make(chan any) source := make(chan interface{})
go func() { go func() {
ring := collection.NewRing(int(n)) ring := collection.NewRing(int(n))
@@ -470,7 +446,7 @@ func (s Stream) Walk(fn WalkFunc, opts ...Option) Stream {
} }
func (s Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream { func (s Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {
pipe := make(chan any, option.workers) pipe := make(chan interface{}, option.workers)
go func() { go func() {
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -501,7 +477,7 @@ func (s Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {
} }
func (s Stream) walkUnlimited(fn WalkFunc, option *rxOptions) Stream { func (s Stream) walkUnlimited(fn WalkFunc, option *rxOptions) Stream {
pipe := make(chan any, option.workers) pipe := make(chan interface{}, option.workers)
go func() { go func() {
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -553,7 +529,7 @@ func buildOptions(opts ...Option) *rxOptions {
} }
// drain drains the given channel. // drain drains the given channel.
func drain(channel <-chan any) { func drain(channel <-chan interface{}) {
for range channel { for range channel {
} }
} }

View File

@@ -23,7 +23,7 @@ func TestBuffer(t *testing.T) {
var count int32 var count int32
var wait sync.WaitGroup var wait sync.WaitGroup
wait.Add(1) wait.Add(1)
From(func(source chan<- any) { From(func(source chan<- interface{}) {
ticker := time.NewTicker(10 * time.Millisecond) ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop() defer ticker.Stop()
@@ -36,7 +36,7 @@ func TestBuffer(t *testing.T) {
return return
} }
} }
}).Buffer(N).ForAll(func(pipe <-chan any) { }).Buffer(N).ForAll(func(pipe <-chan interface{}) {
wait.Wait() wait.Wait()
// why N+1, because take one more to wait for sending into the channel // why N+1, because take one more to wait for sending into the channel
assert.Equal(t, int32(N+1), atomic.LoadInt32(&count)) assert.Equal(t, int32(N+1), atomic.LoadInt32(&count))
@@ -47,7 +47,7 @@ func TestBuffer(t *testing.T) {
func TestBufferNegative(t *testing.T) { func TestBufferNegative(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var result int var result int
Just(1, 2, 3, 4).Buffer(-1).Reduce(func(pipe <-chan any) (any, error) { Just(1, 2, 3, 4).Buffer(-1).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe { for item := range pipe {
result += item.(int) result += item.(int)
} }
@@ -61,22 +61,22 @@ func TestCount(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
elements []any elements []interface{}
}{ }{
{ {
name: "no elements with nil", name: "no elements with nil",
}, },
{ {
name: "no elements", name: "no elements",
elements: []any{}, elements: []interface{}{},
}, },
{ {
name: "1 element", name: "1 element",
elements: []any{1}, elements: []interface{}{1},
}, },
{ {
name: "multiple elements", name: "multiple elements",
elements: []any{1, 2, 3}, elements: []interface{}{1, 2, 3},
}, },
} }
@@ -92,7 +92,7 @@ func TestCount(t *testing.T) {
func TestDone(t *testing.T) { func TestDone(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var count int32 var count int32
Just(1, 2, 3).Walk(func(item any, pipe chan<- any) { Just(1, 2, 3).Walk(func(item interface{}, pipe chan<- interface{}) {
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, int32(item.(int))) atomic.AddInt32(&count, int32(item.(int)))
}).Done() }).Done()
@@ -103,7 +103,7 @@ func TestDone(t *testing.T) {
func TestJust(t *testing.T) { func TestJust(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var result int var result int
Just(1, 2, 3, 4).Reduce(func(pipe <-chan any) (any, error) { Just(1, 2, 3, 4).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe { for item := range pipe {
result += item.(int) result += item.(int)
} }
@@ -116,9 +116,9 @@ func TestJust(t *testing.T) {
func TestDistinct(t *testing.T) { func TestDistinct(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var result int var result int
Just(4, 1, 3, 2, 3, 4).Distinct(func(item any) any { Just(4, 1, 3, 2, 3, 4).Distinct(func(item interface{}) interface{} {
return item return item
}).Reduce(func(pipe <-chan any) (any, error) { }).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe { for item := range pipe {
result += item.(int) result += item.(int)
} }
@@ -131,9 +131,9 @@ func TestDistinct(t *testing.T) {
func TestFilter(t *testing.T) { func TestFilter(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var result int var result int
Just(1, 2, 3, 4).Filter(func(item any) bool { Just(1, 2, 3, 4).Filter(func(item interface{}) bool {
return item.(int)%2 == 0 return item.(int)%2 == 0
}).Reduce(func(pipe <-chan any) (any, error) { }).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe { for item := range pipe {
result += item.(int) result += item.(int)
} }
@@ -154,9 +154,9 @@ func TestFirst(t *testing.T) {
func TestForAll(t *testing.T) { func TestForAll(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var result int var result int
Just(1, 2, 3, 4).Filter(func(item any) bool { Just(1, 2, 3, 4).Filter(func(item interface{}) bool {
return item.(int)%2 == 0 return item.(int)%2 == 0
}).ForAll(func(pipe <-chan any) { }).ForAll(func(pipe <-chan interface{}) {
for item := range pipe { for item := range pipe {
result += item.(int) result += item.(int)
} }
@@ -168,11 +168,11 @@ func TestForAll(t *testing.T) {
func TestGroup(t *testing.T) { func TestGroup(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var groups [][]int var groups [][]int
Just(10, 11, 20, 21).Group(func(item any) any { Just(10, 11, 20, 21).Group(func(item interface{}) interface{} {
v := item.(int) v := item.(int)
return v / 10 return v / 10
}).ForEach(func(item any) { }).ForEach(func(item interface{}) {
v := item.([]any) v := item.([]interface{})
var group []int var group []int
for _, each := range v { for _, each := range v {
group = append(group, each.(int)) group = append(group, each.(int))
@@ -191,7 +191,7 @@ func TestGroup(t *testing.T) {
func TestHead(t *testing.T) { func TestHead(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var result int var result int
Just(1, 2, 3, 4).Head(2).Reduce(func(pipe <-chan any) (any, error) { Just(1, 2, 3, 4).Head(2).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe { for item := range pipe {
result += item.(int) result += item.(int)
} }
@@ -204,7 +204,7 @@ func TestHead(t *testing.T) {
func TestHeadZero(t *testing.T) { func TestHeadZero(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
assert.Panics(t, func() { assert.Panics(t, func() {
Just(1, 2, 3, 4).Head(0).Reduce(func(pipe <-chan any) (any, error) { Just(1, 2, 3, 4).Head(0).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
return nil, nil return nil, nil
}) })
}) })
@@ -214,7 +214,7 @@ func TestHeadZero(t *testing.T) {
func TestHeadMore(t *testing.T) { func TestHeadMore(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var result int var result int
Just(1, 2, 3, 4).Head(6).Reduce(func(pipe <-chan any) (any, error) { Just(1, 2, 3, 4).Head(6).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe { for item := range pipe {
result += item.(int) result += item.(int)
} }
@@ -245,14 +245,14 @@ func TestMap(t *testing.T) {
expect int expect int
}{ }{
{ {
mapper: func(item any) any { mapper: func(item interface{}) interface{} {
v := item.(int) v := item.(int)
return v * v return v * v
}, },
expect: 30, expect: 30,
}, },
{ {
mapper: func(item any) any { mapper: func(item interface{}) interface{} {
v := item.(int) v := item.(int)
if v%2 == 0 { if v%2 == 0 {
return 0 return 0
@@ -262,7 +262,7 @@ func TestMap(t *testing.T) {
expect: 10, expect: 10,
}, },
{ {
mapper: func(item any) any { mapper: func(item interface{}) interface{} {
v := item.(int) v := item.(int)
if v%2 == 0 { if v%2 == 0 {
panic(v) panic(v)
@@ -283,12 +283,12 @@ func TestMap(t *testing.T) {
} else { } else {
workers = runtime.NumCPU() workers = runtime.NumCPU()
} }
From(func(source chan<- any) { From(func(source chan<- interface{}) {
for i := 1; i < 5; i++ { for i := 1; i < 5; i++ {
source <- i source <- i
} }
}).Map(test.mapper, WithWorkers(workers)).Reduce( }).Map(test.mapper, WithWorkers(workers)).Reduce(
func(pipe <-chan any) (any, error) { func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe { for item := range pipe {
result += item.(int) result += item.(int)
} }
@@ -303,8 +303,8 @@ func TestMap(t *testing.T) {
func TestMerge(t *testing.T) { func TestMerge(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
Just(1, 2, 3, 4).Merge().ForEach(func(item any) { Just(1, 2, 3, 4).Merge().ForEach(func(item interface{}) {
assert.ElementsMatch(t, []any{1, 2, 3, 4}, item.([]any)) assert.ElementsMatch(t, []interface{}{1, 2, 3, 4}, item.([]interface{}))
}) })
}) })
} }
@@ -312,7 +312,7 @@ func TestMerge(t *testing.T) {
func TestParallelJust(t *testing.T) { func TestParallelJust(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var count int32 var count int32
Just(1, 2, 3).Parallel(func(item any) { Just(1, 2, 3).Parallel(func(item interface{}) {
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
atomic.AddInt32(&count, int32(item.(int))) atomic.AddInt32(&count, int32(item.(int)))
}, UnlimitedWorkers()) }, UnlimitedWorkers())
@@ -322,8 +322,8 @@ func TestParallelJust(t *testing.T) {
func TestReverse(t *testing.T) { func TestReverse(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
Just(1, 2, 3, 4).Reverse().Merge().ForEach(func(item any) { Just(1, 2, 3, 4).Reverse().Merge().ForEach(func(item interface{}) {
assert.ElementsMatch(t, []any{4, 3, 2, 1}, item.([]any)) assert.ElementsMatch(t, []interface{}{4, 3, 2, 1}, item.([]interface{}))
}) })
}) })
} }
@@ -331,9 +331,9 @@ func TestReverse(t *testing.T) {
func TestSort(t *testing.T) { func TestSort(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var prev int var prev int
Just(5, 3, 7, 1, 9, 6, 4, 8, 2).Sort(func(a, b any) bool { Just(5, 3, 7, 1, 9, 6, 4, 8, 2).Sort(func(a, b interface{}) bool {
return a.(int) < b.(int) return a.(int) < b.(int)
}).ForEach(func(item any) { }).ForEach(func(item interface{}) {
next := item.(int) next := item.(int)
assert.True(t, prev < next) assert.True(t, prev < next)
prev = next prev = next
@@ -346,12 +346,12 @@ func TestSplit(t *testing.T) {
assert.Panics(t, func() { assert.Panics(t, func() {
Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(0).Done() Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(0).Done()
}) })
var chunks [][]any var chunks [][]interface{}
Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(4).ForEach(func(item any) { Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(4).ForEach(func(item interface{}) {
chunk := item.([]any) chunk := item.([]interface{})
chunks = append(chunks, chunk) chunks = append(chunks, chunk)
}) })
assert.EqualValues(t, [][]any{ assert.EqualValues(t, [][]interface{}{
{1, 2, 3, 4}, {1, 2, 3, 4},
{5, 6, 7, 8}, {5, 6, 7, 8},
{9, 10}, {9, 10},
@@ -362,7 +362,7 @@ func TestSplit(t *testing.T) {
func TestTail(t *testing.T) { func TestTail(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var result int var result int
Just(1, 2, 3, 4).Tail(2).Reduce(func(pipe <-chan any) (any, error) { Just(1, 2, 3, 4).Tail(2).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
for item := range pipe { for item := range pipe {
result += item.(int) result += item.(int)
} }
@@ -375,7 +375,7 @@ func TestTail(t *testing.T) {
func TestTailZero(t *testing.T) { func TestTailZero(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
assert.Panics(t, func() { assert.Panics(t, func() {
Just(1, 2, 3, 4).Tail(0).Reduce(func(pipe <-chan any) (any, error) { Just(1, 2, 3, 4).Tail(0).Reduce(func(pipe <-chan interface{}) (interface{}, error) {
return nil, nil return nil, nil
}) })
}) })
@@ -385,11 +385,11 @@ func TestTailZero(t *testing.T) {
func TestWalk(t *testing.T) { func TestWalk(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
var result int var result int
Just(1, 2, 3, 4, 5).Walk(func(item any, pipe chan<- any) { Just(1, 2, 3, 4, 5).Walk(func(item interface{}, pipe chan<- interface{}) {
if item.(int)%2 != 0 { if item.(int)%2 != 0 {
pipe <- item pipe <- item
} }
}, UnlimitedWorkers()).ForEach(func(item any) { }, UnlimitedWorkers()).ForEach(func(item interface{}) {
result += item.(int) result += item.(int)
}) })
assert.Equal(t, 9, result) assert.Equal(t, 9, result)
@@ -398,16 +398,16 @@ func TestWalk(t *testing.T) {
func TestStream_AnyMach(t *testing.T) { func TestStream_AnyMach(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
assetEqual(t, false, Just(1, 2, 3).AnyMatch(func(item any) bool { assetEqual(t, false, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
return item.(int) == 4 return item.(int) == 4
})) }))
assetEqual(t, false, Just(1, 2, 3).AnyMatch(func(item any) bool { assetEqual(t, false, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
return item.(int) == 0 return item.(int) == 0
})) }))
assetEqual(t, true, Just(1, 2, 3).AnyMatch(func(item any) bool { assetEqual(t, true, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
return item.(int) == 2 return item.(int) == 2
})) }))
assetEqual(t, true, Just(1, 2, 3).AnyMatch(func(item any) bool { assetEqual(t, true, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
return item.(int) == 2 return item.(int) == 2
})) }))
}) })
@@ -416,17 +416,17 @@ func TestStream_AnyMach(t *testing.T) {
func TestStream_AllMach(t *testing.T) { func TestStream_AllMach(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
assetEqual( assetEqual(
t, true, Just(1, 2, 3).AllMatch(func(item any) bool { t, true, Just(1, 2, 3).AllMach(func(item interface{}) bool {
return true return true
}), }),
) )
assetEqual( assetEqual(
t, false, Just(1, 2, 3).AllMatch(func(item any) bool { t, false, Just(1, 2, 3).AllMach(func(item interface{}) bool {
return false return false
}), }),
) )
assetEqual( assetEqual(
t, false, Just(1, 2, 3).AllMatch(func(item any) bool { t, false, Just(1, 2, 3).AllMach(func(item interface{}) bool {
return item.(int) == 1 return item.(int) == 1
}), }),
) )
@@ -436,17 +436,17 @@ func TestStream_AllMach(t *testing.T) {
func TestStream_NoneMatch(t *testing.T) { func TestStream_NoneMatch(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
assetEqual( assetEqual(
t, true, Just(1, 2, 3).NoneMatch(func(item any) bool { t, true, Just(1, 2, 3).NoneMatch(func(item interface{}) bool {
return false return false
}), }),
) )
assetEqual( assetEqual(
t, false, Just(1, 2, 3).NoneMatch(func(item any) bool { t, false, Just(1, 2, 3).NoneMatch(func(item interface{}) bool {
return true return true
}), }),
) )
assetEqual( assetEqual(
t, true, Just(1, 2, 3).NoneMatch(func(item any) bool { t, true, Just(1, 2, 3).NoneMatch(func(item interface{}) bool {
return item.(int) == 4 return item.(int) == 4
}), }),
) )
@@ -455,19 +455,19 @@ func TestStream_NoneMatch(t *testing.T) {
func TestConcat(t *testing.T) { func TestConcat(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
a1 := []any{1, 2, 3} a1 := []interface{}{1, 2, 3}
a2 := []any{4, 5, 6} a2 := []interface{}{4, 5, 6}
s1 := Just(a1...) s1 := Just(a1...)
s2 := Just(a2...) s2 := Just(a2...)
stream := Concat(s1, s2) stream := Concat(s1, s2)
var items []any var items []interface{}
for item := range stream.source { for item := range stream.source {
items = append(items, item) items = append(items, item)
} }
sort.Slice(items, func(i, j int) bool { sort.Slice(items, func(i, j int) bool {
return items[i].(int) < items[j].(int) return items[i].(int) < items[j].(int)
}) })
ints := make([]any, 0) ints := make([]interface{}, 0)
ints = append(ints, a1...) ints = append(ints, a1...)
ints = append(ints, a2...) ints = append(ints, a2...)
assetEqual(t, ints, items) assetEqual(t, ints, items)
@@ -479,7 +479,7 @@ func TestStream_Skip(t *testing.T) {
assetEqual(t, 3, Just(1, 2, 3, 4).Skip(1).Count()) assetEqual(t, 3, Just(1, 2, 3, 4).Skip(1).Count())
assetEqual(t, 1, Just(1, 2, 3, 4).Skip(3).Count()) assetEqual(t, 1, Just(1, 2, 3, 4).Skip(3).Count())
assetEqual(t, 4, Just(1, 2, 3, 4).Skip(0).Count()) assetEqual(t, 4, Just(1, 2, 3, 4).Skip(0).Count())
equal(t, Just(1, 2, 3, 4).Skip(3), []any{4}) equal(t, Just(1, 2, 3, 4).Skip(3), []interface{}{4})
assert.Panics(t, func() { assert.Panics(t, func() {
Just(1, 2, 3, 4).Skip(-1) Just(1, 2, 3, 4).Skip(-1)
}) })
@@ -489,104 +489,27 @@ func TestStream_Skip(t *testing.T) {
func TestStream_Concat(t *testing.T) { func TestStream_Concat(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
stream := Just(1).Concat(Just(2), Just(3)) stream := Just(1).Concat(Just(2), Just(3))
var items []any var items []interface{}
for item := range stream.source { for item := range stream.source {
items = append(items, item) items = append(items, item)
} }
sort.Slice(items, func(i, j int) bool { sort.Slice(items, func(i, j int) bool {
return items[i].(int) < items[j].(int) return items[i].(int) < items[j].(int)
}) })
assetEqual(t, []any{1, 2, 3}, items) assetEqual(t, []interface{}{1, 2, 3}, items)
just := Just(1) just := Just(1)
equal(t, just.Concat(just), []any{1}) equal(t, just.Concat(just), []interface{}{1})
})
}
func TestStream_Max(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
tests := []struct {
name string
elements []any
max any
}{
{
name: "no elements with nil",
},
{
name: "no elements",
elements: []any{},
max: nil,
},
{
name: "1 element",
elements: []any{1},
max: 1,
},
{
name: "multiple elements",
elements: []any{1, 2, 9, 5, 8},
max: 9,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
val := Just(test.elements...).Max(func(a, b any) bool {
return a.(int) < b.(int)
})
assetEqual(t, test.max, val)
})
}
})
}
func TestStream_Min(t *testing.T) {
runCheckedTest(t, func(t *testing.T) {
tests := []struct {
name string
elements []any
min any
}{
{
name: "no elements with nil",
min: nil,
},
{
name: "no elements",
elements: []any{},
min: nil,
},
{
name: "1 element",
elements: []any{1},
min: 1,
},
{
name: "multiple elements",
elements: []any{-1, 1, 2, 9, 5, 8},
min: -1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
val := Just(test.elements...).Min(func(a, b any) bool {
return a.(int) < b.(int)
})
assetEqual(t, test.min, val)
})
}
}) })
} }
func BenchmarkParallelMapReduce(b *testing.B) { func BenchmarkParallelMapReduce(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
mapper := func(v any) any { mapper := func(v interface{}) interface{} {
return v.(int64) * v.(int64) return v.(int64) * v.(int64)
} }
reducer := func(input <-chan any) (any, error) { reducer := func(input <-chan interface{}) (interface{}, error) {
var result int64 var result int64
for v := range input { for v := range input {
result += v.(int64) result += v.(int64)
@@ -594,7 +517,7 @@ func BenchmarkParallelMapReduce(b *testing.B) {
return result, nil return result, nil
} }
b.ResetTimer() b.ResetTimer()
From(func(input chan<- any) { From(func(input chan<- interface{}) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
input <- int64(rand.Int()) input <- int64(rand.Int())
@@ -606,10 +529,10 @@ func BenchmarkParallelMapReduce(b *testing.B) {
func BenchmarkMapReduce(b *testing.B) { func BenchmarkMapReduce(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
mapper := func(v any) any { mapper := func(v interface{}) interface{} {
return v.(int64) * v.(int64) return v.(int64) * v.(int64)
} }
reducer := func(input <-chan any) (any, error) { reducer := func(input <-chan interface{}) (interface{}, error) {
var result int64 var result int64
for v := range input { for v := range input {
result += v.(int64) result += v.(int64)
@@ -617,21 +540,21 @@ func BenchmarkMapReduce(b *testing.B) {
return result, nil return result, nil
} }
b.ResetTimer() b.ResetTimer()
From(func(input chan<- any) { From(func(input chan<- interface{}) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
input <- int64(rand.Int()) input <- int64(rand.Int())
} }
}).Map(mapper).Reduce(reducer) }).Map(mapper).Reduce(reducer)
} }
func assetEqual(t *testing.T, except, data any) { func assetEqual(t *testing.T, except, data interface{}) {
if !reflect.DeepEqual(except, data) { if !reflect.DeepEqual(except, data) {
t.Errorf(" %v, want %v", data, except) t.Errorf(" %v, want %v", data, except)
} }
} }
func equal(t *testing.T, stream Stream, data []any) { func equal(t *testing.T, stream Stream, data []interface{}) {
items := make([]any, 0) items := make([]interface{}, 0)
for item := range stream.source { for item := range stream.source {
items = append(items, item) items = append(items, item)
} }

View File

@@ -29,7 +29,7 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
// create channel with buffer size 1 to avoid goroutine leak // create channel with buffer size 1 to avoid goroutine leak
done := make(chan error, 1) done := make(chan error, 1)
panicChan := make(chan any, 1) panicChan := make(chan interface{}, 1)
go func() { go func() {
defer func() { defer func() {
if p := recover(); p != nil { if p := recover(); p != nil {

View File

@@ -26,7 +26,7 @@ type (
hashFunc Func hashFunc Func
replicas int replicas int
keys []uint64 keys []uint64
ring map[uint64][]any ring map[uint64][]interface{}
nodes map[string]lang.PlaceholderType nodes map[string]lang.PlaceholderType
lock sync.RWMutex lock sync.RWMutex
} }
@@ -50,21 +50,21 @@ func NewCustomConsistentHash(replicas int, fn Func) *ConsistentHash {
return &ConsistentHash{ return &ConsistentHash{
hashFunc: fn, hashFunc: fn,
replicas: replicas, replicas: replicas,
ring: make(map[uint64][]any), ring: make(map[uint64][]interface{}),
nodes: make(map[string]lang.PlaceholderType), nodes: make(map[string]lang.PlaceholderType),
} }
} }
// Add adds the node with the number of h.replicas, // Add adds the node with the number of h.replicas,
// the later call will overwrite the replicas of the former calls. // the later call will overwrite the replicas of the former calls.
func (h *ConsistentHash) Add(node any) { func (h *ConsistentHash) Add(node interface{}) {
h.AddWithReplicas(node, h.replicas) h.AddWithReplicas(node, h.replicas)
} }
// AddWithReplicas adds the node with the number of replicas, // AddWithReplicas adds the node with the number of replicas,
// replicas will be truncated to h.replicas if it's larger than h.replicas, // replicas will be truncated to h.replicas if it's larger than h.replicas,
// the later call will overwrite the replicas of the former calls. // the later call will overwrite the replicas of the former calls.
func (h *ConsistentHash) AddWithReplicas(node any, replicas int) { func (h *ConsistentHash) AddWithReplicas(node interface{}, replicas int) {
h.Remove(node) h.Remove(node)
if replicas > h.replicas { if replicas > h.replicas {
@@ -89,7 +89,7 @@ func (h *ConsistentHash) AddWithReplicas(node any, replicas int) {
// AddWithWeight adds the node with weight, the weight can be 1 to 100, indicates the percent, // AddWithWeight adds the node with weight, the weight can be 1 to 100, indicates the percent,
// the later call will overwrite the replicas of the former calls. // the later call will overwrite the replicas of the former calls.
func (h *ConsistentHash) AddWithWeight(node any, weight int) { func (h *ConsistentHash) AddWithWeight(node interface{}, weight int) {
// don't need to make sure weight not larger than TopWeight, // don't need to make sure weight not larger than TopWeight,
// because AddWithReplicas makes sure replicas cannot be larger than h.replicas // because AddWithReplicas makes sure replicas cannot be larger than h.replicas
replicas := h.replicas * weight / TopWeight replicas := h.replicas * weight / TopWeight
@@ -97,7 +97,7 @@ func (h *ConsistentHash) AddWithWeight(node any, weight int) {
} }
// Get returns the corresponding node from h base on the given v. // Get returns the corresponding node from h base on the given v.
func (h *ConsistentHash) Get(v any) (any, bool) { func (h *ConsistentHash) Get(v interface{}) (interface{}, bool) {
h.lock.RLock() h.lock.RLock()
defer h.lock.RUnlock() defer h.lock.RUnlock()
@@ -124,7 +124,7 @@ func (h *ConsistentHash) Get(v any) (any, bool) {
} }
// Remove removes the given node from h. // Remove removes the given node from h.
func (h *ConsistentHash) Remove(node any) { func (h *ConsistentHash) Remove(node interface{}) {
nodeRepr := repr(node) nodeRepr := repr(node)
h.lock.Lock() h.lock.Lock()
@@ -177,10 +177,10 @@ func (h *ConsistentHash) removeNode(nodeRepr string) {
delete(h.nodes, nodeRepr) delete(h.nodes, nodeRepr)
} }
func innerRepr(node any) string { func innerRepr(node interface{}) string {
return fmt.Sprintf("%d:%v", prime, node) return fmt.Sprintf("%d:%v", prime, node)
} }
func repr(node any) string { func repr(node interface{}) string {
return lang.Repr(node) return lang.Repr(node)
} }

Some files were not shown because too many files have changed in this diff Show More