Compare commits

..

1 Commits

Author SHA1 Message Date
spectatorMrZ
304fb182bb Fix pg model generation without tag (#1407)
1. fix pg model struct haven't tag
2. add pg command test from datasource
2022-01-08 21:48:09 +08:00
873 changed files with 15640 additions and 42302 deletions

View File

@@ -1,6 +1,3 @@
comment: comment: false
layout: "flags, files"
behavior: once
require_changes: true
ignore: ignore:
- "tools" - "tools"

3
.github/FUNDING.yml vendored
View File

@@ -9,5 +9,4 @@ community_bridge: # Replace with a single Community Bridge project-name e.g., cl
liberapay: # Replace with a single Liberapay username liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username otechie: # Replace with a single Otechie username
custom: # https://gitee.com/kevwan/static/raw/master/images/sponsor.jpg custom: https://gitee.com/kevwan/static/raw/master/images/sponsor.jpg
ethereum: 0x5052b7f6B937B02563996D23feb69b38D06Ca150 | kevwan

View File

@@ -1,11 +0,0 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "gomod" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "daily"

View File

@@ -35,11 +35,11 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v2
# Initializes the CodeQL tools for scanning. # Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL - name: Initialize CodeQL
uses: github/codeql-action/init@v2 uses: github/codeql-action/init@v1
with: with:
languages: ${{ matrix.language }} languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file. # If you wish to specify custom queries, you can do so here or in a config file.
@@ -50,7 +50,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below) # If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild - name: Autobuild
uses: github/codeql-action/autobuild@v2 uses: github/codeql-action/autobuild@v1
# Command-line programs to run using the OS shell. # Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl # 📚 https://git.io/JvXDl
@@ -64,4 +64,4 @@ jobs:
# make release # make release
- name: Perform CodeQL Analysis - name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2 uses: github/codeql-action/analyze@v1

View File

@@ -7,55 +7,32 @@ on:
branches: [ master ] branches: [ master ]
jobs: jobs:
test-linux: build:
name: Linux name: Build
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v3
- name: Set up Go 1.x - name: Set up Go 1.x
uses: actions/setup-go@v3 uses: actions/setup-go@v2
with: with:
go-version: ^1.16 go-version: ^1.14
check-latest: true id: go
cache: true
id: go
- name: Get dependencies - name: Check out code into the Go module directory
run: | uses: actions/checkout@v2
go get -v -t -d ./...
- name: Lint - name: Get dependencies
run: | run: |
go vet -stdmethods=false $(go list ./...) go get -v -t -d ./...
go install mvdan.cc/gofumpt@latest
test -z "$(gofumpt -l -extra .)" || echo "Please run 'gofumpt -l -w -extra .'"
- name: Test - name: Lint
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./... run: |
go vet -stdmethods=false $(go list ./...)
go install mvdan.cc/gofumpt@latest
test -z "$(gofumpt -s -l -extra .)" || echo "Please run 'gofumpt -l -w -extra .'"
- name: Codecov - name: Test
uses: codecov/codecov-action@v3 run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
test-win: - name: Codecov
name: Windows uses: codecov/codecov-action@v2
runs-on: windows-latest
steps:
- name: Checkout codebase
uses: actions/checkout@v3
- name: Set up Go 1.x
uses: actions/setup-go@v3
with:
# use 1.16 to guarantee Go 1.16 compatibility
go-version: 1.16
check-latest: true
cache: true
- name: Test
run: |
go mod verify
go mod download
go test -v -race ./...
cd tools/goctl && go build -v goctl.go

View File

@@ -1,18 +0,0 @@
name: 'issue-translator'
on:
issue_comment:
types: [created]
issues:
types: [opened]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: usthe/issues-translate-action@v2.7
with:
IS_MODIFY_TITLE: true
# not require, default false, . Decide whether to modify the issue title
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑‍🤝‍🧑👫🧑🏿‍🤝‍🧑🏻👩🏾‍🤝‍👨🏿👬🏿
# not require. Customize the translation robot prefix message.

View File

@@ -7,10 +7,10 @@ jobs:
close-issues: close-issues:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/stale@v6 - uses: actions/stale@v3
with: with:
days-before-issue-stale: 365 days-before-issue-stale: 30
days-before-issue-close: 90 days-before-issue-close: 14
stale-issue-label: "stale" stale-issue-label: "stale"
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."

View File

@@ -1,28 +0,0 @@
on:
push:
tags:
- "tools/goctl/*"
jobs:
releases-matrix:
name: Release goctl binary
runs-on: ubuntu-latest
strategy:
matrix:
# build and publish in parallel: linux/386, linux/amd64, linux/arm64,
# windows/386, windows/amd64, windows/arm64, darwin/amd64, darwin/arm64
goos: [ linux, windows, darwin ]
goarch: [ "386", amd64, arm64 ]
exclude:
- goarch: "386"
goos: darwin
steps:
- uses: actions/checkout@v3
- uses: zeromicro/go-zero-release-action@master
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
goos: ${{ matrix.goos }}
goarch: ${{ matrix.goarch }}
goversion: "https://dl.google.com/go/go1.17.5.linux-amd64.tar.gz"
project_path: "tools/goctl"
binary_name: "goctl"
extra_files: tools/goctl/readme.md tools/goctl/readme-cn.md

View File

@@ -5,7 +5,7 @@ jobs:
name: runner / staticcheck name: runner / staticcheck
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v2
- uses: reviewdog/action-staticcheck@v1 - uses: reviewdog/action-staticcheck@v1
with: with:
github_token: ${{ secrets.github_token }} github_token: ${{ secrets.github_token }}

5
.gitignore vendored
View File

@@ -16,13 +16,10 @@
**/logs **/logs
# for test purpose # for test purpose
**/adhoc adhoc
go.work
go.work.sum
# gitlab ci # gitlab ci
.cache .cache
.golangci.yml
# vim auto backup file # vim auto backup file
*~ *~

View File

@@ -40,7 +40,7 @@ We will help you to contribute in different areas like filing issues, developing
getting your work reviewed and merged. getting your work reviewed and merged.
If you have questions about the development process, If you have questions about the development process,
feel free to [file an issue](https://github.com/zeromicro/go-zero/issues/new/choose). feel free to [file an issue](https://github.com/tal-tech/go-zero/issues/new/choose).
## Find something to work on ## Find something to work on
@@ -50,10 +50,10 @@ Here is how you get started.
### Find a good first topic ### Find a good first topic
[go-zero](https://github.com/zeromicro/go-zero) has beginner-friendly issues that provide a good first issue. [go-zero](https://github.com/tal-tech/go-zero) has beginner-friendly issues that provide a good first issue.
For example, [go-zero](https://github.com/zeromicro/go-zero) has For example, [go-zero](https://github.com/tal-tech/go-zero) has
[help wanted](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) and [help wanted](https://github.com/tal-tech/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) and
[good first issue](https://github.com/zeromicro/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) [good first issue](https://github.com/tal-tech/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
labels for issues that should not need deep knowledge of the system. labels for issues that should not need deep knowledge of the system.
We can help new contributors who wish to work on such issues. We can help new contributors who wish to work on such issues.
@@ -79,7 +79,7 @@ This is a rough outline of what a contributor's workflow looks like:
- Create a topic branch from where to base the contribution. This is usually master. - Create a topic branch from where to base the contribution. This is usually master.
- Make commits of logical units. - Make commits of logical units.
- Push changes in a topic branch to a personal fork of the repository. - Push changes in a topic branch to a personal fork of the repository.
- Submit a pull request to [go-zero](https://github.com/zeromicro/go-zero). - Submit a pull request to [go-zero](https://github.com/tal-tech/go-zero).
## Creating Pull Requests ## Creating Pull Requests

View File

@@ -1,6 +1,6 @@
MIT License MIT License
Copyright (c) 2022 zeromicro Copyright (c) 2020 xiaoheiban_server_go
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal

28
ROADMAP.md Normal file
View File

@@ -0,0 +1,28 @@
# go-zero Roadmap
This document defines a high level roadmap for go-zero development and upcoming releases.
Community and contributor involvement is vital for successfully implementing all desired items for each release.
We hope that the items listed below will inspire further engagement from the community to keep go-zero progressing and shipping exciting and valuable features.
## 2021 Q2
- [x] Support service discovery through K8S client api
- [x] Log full sql statements for easier sql problem solving
## 2021 Q3
- [x] Support `goctl model pg` to support PostgreSQL code generation
- [x] Adapt builtin tracing mechanism to opentracing solutions
## 2021 Q4
- [x] Support `username/password` authentication in ETCD
- [x] Support `SSL/TLS` in ETCD
- [x] Support `SSL/TLS` in `zRPC`
- [x] Support `TLS` in redis connections
- [x] Support `goctl bug` to report bugs conveniently
## 2022
- [ ] Support `goctl mock` command to start a mocking server with given `.api` file
- [ ] Add `httpx.Client` with governance, like circuit breaker etc.
- [ ] Support `goctl doctor` command to report potential issues for given service
- [ ] Support `context` in redis related methods for timeout and tracing
- [ ] Support `context` in sql related methods for timeout and tracing
- [ ] Support `context` in mongodb related methods for timeout and tracing

View File

@@ -4,8 +4,8 @@ import (
"errors" "errors"
"strconv" "strconv"
"github.com/zeromicro/go-zero/core/hash" "github.com/tal-tech/go-zero/core/hash"
"github.com/zeromicro/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
) )
const ( const (
@@ -69,8 +69,11 @@ func (f *Filter) Exists(data []byte) (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
if !isSet {
return false, nil
}
return isSet, nil return true, nil
} }
func (f *Filter) getLocations(data []byte) []uint { func (f *Filter) getLocations(data []byte) []uint {

View File

@@ -4,7 +4,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stores/redis/redistest" "github.com/tal-tech/go-zero/core/stores/redis/redistest"
) )
func TestRedisBitSet_New_Set_Test(t *testing.T) { func TestRedisBitSet_New_Set_Test(t *testing.T) {

View File

@@ -5,12 +5,12 @@ import (
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
"time"
"github.com/zeromicro/go-zero/core/mathx" "github.com/tal-tech/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/proc" "github.com/tal-tech/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/stat" "github.com/tal-tech/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/core/timex"
) )
const ( const (
@@ -171,7 +171,7 @@ func (lt loggedThrottle) allow() (Promise, error) {
func (lt loggedThrottle) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error { func (lt loggedThrottle) doReq(req func() error, fallback func(err error) error, acceptable Acceptable) error {
return lt.logError(lt.internalThrottle.doReq(req, fallback, func(err error) bool { return lt.logError(lt.internalThrottle.doReq(req, fallback, func(err error) bool {
accept := acceptable(err) accept := acceptable(err)
if !accept && err != nil { if !accept {
lt.errWin.add(err.Error()) lt.errWin.add(err.Error())
} }
return accept return accept
@@ -198,7 +198,7 @@ type errorWindow struct {
func (ew *errorWindow) add(reason string) { func (ew *errorWindow) add(reason string) {
ew.lock.Lock() ew.lock.Lock()
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason) ew.reasons[ew.index] = fmt.Sprintf("%s %s", timex.Time().Format(timeFormat), reason)
ew.index = (ew.index + 1) % numHistoryReasons ew.index = (ew.index + 1) % numHistoryReasons
ew.count = mathx.MinInt(ew.count+1, numHistoryReasons) ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
ew.lock.Unlock() ew.lock.Unlock()

View File

@@ -8,7 +8,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stat" "github.com/tal-tech/go-zero/core/stat"
) )
func init() { func init() {

View File

@@ -6,7 +6,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stat" "github.com/tal-tech/go-zero/core/stat"
) )
func init() { func init() {

View File

@@ -4,8 +4,8 @@ import (
"math" "math"
"time" "time"
"github.com/zeromicro/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/mathx" "github.com/tal-tech/go-zero/core/mathx"
) )
const ( const (

View File

@@ -7,9 +7,9 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/mathx" "github.com/tal-tech/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/stat" "github.com/tal-tech/go-zero/core/stat"
) )
const ( const (

View File

@@ -20,16 +20,16 @@ func (b noOpBreaker) Do(req func() error) error {
return req() return req()
} }
func (b noOpBreaker) DoWithAcceptable(req func() error, _ Acceptable) error { func (b noOpBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error {
return req() return req()
} }
func (b noOpBreaker) DoWithFallback(req func() error, _ func(err error) error) error { func (b noOpBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
return req() return req()
} }
func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, _ func(err error) error, func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
_ Acceptable) error { acceptable Acceptable) error {
return req() return req()
} }
@@ -38,5 +38,5 @@ type nopPromise struct{}
func (p nopPromise) Accept() { func (p nopPromise) Accept() {
} }
func (p nopPromise) Reject(_ string) { func (p nopPromise) Reject(reason string) {
} }

View File

@@ -8,8 +8,8 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/iox" "github.com/tal-tech/go-zero/core/iox"
"github.com/zeromicro/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
) )
func TestEnterToContinue(t *testing.T) { func TestEnterToContinue(t *testing.T) {

View File

@@ -7,7 +7,7 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
) )
// ErrPaddingSize indicates bad padding size. // ErrPaddingSize indicates bad padding size.
@@ -32,11 +32,9 @@ func NewECBEncrypter(b cipher.Block) cipher.BlockMode {
return (*ecbEncrypter)(newECB(b)) return (*ecbEncrypter)(newECB(b))
} }
// BlockSize returns the mode's block size.
func (x *ecbEncrypter) BlockSize() int { return x.blockSize } func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
// CryptBlocks encrypts a number of blocks. The length of src must be a multiple of // why we don't return error is because cipher.BlockMode doesn't allow this
// the block size. Dst and src must overlap entirely or not at all.
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) { func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 { if len(src)%x.blockSize != 0 {
logx.Error("crypto/cipher: input not full blocks") logx.Error("crypto/cipher: input not full blocks")
@@ -61,13 +59,11 @@ func NewECBDecrypter(b cipher.Block) cipher.BlockMode {
return (*ecbDecrypter)(newECB(b)) return (*ecbDecrypter)(newECB(b))
} }
// BlockSize returns the mode's block size.
func (x *ecbDecrypter) BlockSize() int { func (x *ecbDecrypter) BlockSize() int {
return x.blockSize return x.blockSize
} }
// CryptBlocks decrypts a number of blocks. The length of src must be a multiple of // why we don't return error is because cipher.BlockMode doesn't allow this
// the block size. Dst and src must overlap entirely or not at all.
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) { func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 { if len(src)%x.blockSize != 0 {
logx.Error("crypto/cipher: input not full blocks") logx.Error("crypto/cipher: input not full blocks")

View File

@@ -1,7 +1,6 @@
package codec package codec
import ( import (
"crypto/aes"
"encoding/base64" "encoding/base64"
"testing" "testing"
@@ -11,8 +10,7 @@ import (
func TestAesEcb(t *testing.T) { func TestAesEcb(t *testing.T) {
var ( var (
key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D") key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
val = []byte("helloworld") val = []byte("hello")
valLong = []byte("helloworldlong..")
badKey1 = []byte("aaaaaaaaa") badKey1 = []byte("aaaaaaaaa")
// more than 32 chars // more than 32 chars
badKey2 = []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") badKey2 = []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
@@ -33,39 +31,6 @@ func TestAesEcb(t *testing.T) {
src, err := EcbDecrypt(key, dst) src, err := EcbDecrypt(key, dst)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, val, src) assert.Equal(t, val, src)
block, err := aes.NewCipher(key)
assert.NoError(t, err)
encrypter := NewECBEncrypter(block)
assert.Equal(t, 16, encrypter.BlockSize())
decrypter := NewECBDecrypter(block)
assert.Equal(t, 16, decrypter.BlockSize())
dst = make([]byte, 8)
encrypter.CryptBlocks(dst, val)
for _, b := range dst {
assert.Equal(t, byte(0), b)
}
dst = make([]byte, 8)
encrypter.CryptBlocks(dst, valLong)
for _, b := range dst {
assert.Equal(t, byte(0), b)
}
dst = make([]byte, 8)
decrypter.CryptBlocks(dst, val)
for _, b := range dst {
assert.Equal(t, byte(0), b)
}
dst = make([]byte, 8)
decrypter.CryptBlocks(dst, valLong)
for _, b := range dst {
assert.Equal(t, byte(0), b)
}
_, err = EcbEncryptBase64("cTR0N3dDKkYtSmFOZFJnVWpYbjJyNXU4eC9BP0QK", "aGVsbG93b3JsZGxvbmcuLgo=")
assert.Error(t, err)
} }
func TestAesEcbBase64(t *testing.T) { func TestAesEcbBase64(t *testing.T) {

View File

@@ -80,17 +80,3 @@ func TestKeyBytes(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, len(key.Bytes()) > 0) assert.True(t, len(key.Bytes()) > 0)
} }
func TestDHOnErrors(t *testing.T) {
key, err := GenerateKey()
assert.Nil(t, err)
assert.NotEmpty(t, key.Bytes())
_, err = ComputeKey(key.PubKey, key.PriKey)
assert.NoError(t, err)
_, err = ComputeKey(nil, key.PriKey)
assert.Error(t, err)
_, err = ComputeKey(key.PubKey, nil)
assert.Error(t, err)
assert.NotNil(t, NewPublicKey([]byte("")))
}

View File

@@ -7,7 +7,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/pem" "encoding/pem"
"errors" "errors"
"os" "io/ioutil"
) )
var ( var (
@@ -48,7 +48,7 @@ type (
// NewRsaDecrypter returns a RsaDecrypter with the given file. // NewRsaDecrypter returns a RsaDecrypter with the given file.
func NewRsaDecrypter(file string) (RsaDecrypter, error) { func NewRsaDecrypter(file string) (RsaDecrypter, error) {
content, err := os.ReadFile(file) content, err := ioutil.ReadFile(file)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -5,7 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/fs" "github.com/tal-tech/go-zero/core/fs"
) )
const ( const (

View File

@@ -6,9 +6,9 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mathx" "github.com/tal-tech/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
) )
const ( const (
@@ -26,7 +26,7 @@ type (
// CacheOption defines the method to customize a Cache. // CacheOption defines the method to customize a Cache.
CacheOption func(cache *Cache) CacheOption func(cache *Cache)
// A Cache object is an in-memory cache. // A Cache object is a in-memory cache.
Cache struct { Cache struct {
name string name string
lock sync.Mutex lock sync.Mutex
@@ -98,18 +98,13 @@ func (c *Cache) Get(key string) (interface{}, bool) {
// Set sets value into c with key. // Set sets value into c with key.
func (c *Cache) Set(key string, value interface{}) { func (c *Cache) Set(key string, value interface{}) {
c.SetWithExpire(key, value, c.expire)
}
// SetWithExpire sets value into c with key and expire with the given value.
func (c *Cache) SetWithExpire(key string, value interface{}, expire time.Duration) {
c.lock.Lock() c.lock.Lock()
_, ok := c.data[key] _, ok := c.data[key]
c.data[key] = value c.data[key] = value
c.lruCache.add(key) c.lruCache.add(key)
c.lock.Unlock() c.lock.Unlock()
expiry := c.unstableExpiry.AroundDuration(expire) expiry := c.unstableExpiry.AroundDuration(c.expire)
if ok { if ok {
c.timingWheel.MoveTimer(key, expiry) c.timingWheel.MoveTimer(key, expiry)
} else { } else {

View File

@@ -18,7 +18,7 @@ func TestCacheSet(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
cache.Set("first", "first element") cache.Set("first", "first element")
cache.SetWithExpire("second", "second element", time.Second*3) cache.Set("second", "second element")
value, ok := cache.Get("first") value, ok := cache.Get("first")
assert.True(t, ok) assert.True(t, ok)

View File

@@ -61,41 +61,3 @@ func TestPutMore(t *testing.T) {
assert.Equal(t, string(element), string(body.([]byte))) assert.Equal(t, string(element), string(body.([]byte)))
} }
} }
func TestPutMoreWithHeaderNotZero(t *testing.T) {
elements := [][]byte{
[]byte("hello"),
[]byte("world"),
[]byte("again"),
}
queue := NewQueue(4)
for i := range elements {
queue.Put(elements[i])
}
// take 1
body, ok := queue.Take()
assert.True(t, ok)
element, ok := body.([]byte)
assert.True(t, ok)
assert.Equal(t, element, []byte("hello"))
// put more
queue.Put([]byte("b4"))
queue.Put([]byte("b5")) // will store in elements[0]
queue.Put([]byte("b6")) // cause expansion
results := [][]byte{
[]byte("world"),
[]byte("again"),
[]byte("b4"),
[]byte("b5"),
[]byte("b6"),
}
for _, element := range results {
body, ok := queue.Take()
assert.True(t, ok)
assert.Equal(t, string(element), string(body.([]byte)))
}
}

View File

@@ -6,7 +6,7 @@ import "sync"
type Ring struct { type Ring struct {
elements []interface{} elements []interface{}
index int index int
lock sync.RWMutex lock sync.Mutex
} }
// NewRing returns a Ring object with the given size n. // NewRing returns a Ring object with the given size n.
@@ -31,8 +31,8 @@ func (r *Ring) Add(v interface{}) {
// Take takes all items from r. // Take takes all items from r.
func (r *Ring) Take() []interface{} { func (r *Ring) Take() []interface{} {
r.lock.RLock() r.lock.Lock()
defer r.lock.RUnlock() defer r.lock.Unlock()
var size int var size int
var start int var start int

View File

@@ -4,7 +4,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/zeromicro/go-zero/core/timex" "github.com/tal-tech/go-zero/core/timex"
) )
type ( type (

View File

@@ -6,7 +6,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
) )
const duration = time.Millisecond * 50 const duration = time.Millisecond * 50

View File

@@ -68,24 +68,6 @@ func (m *SafeMap) Get(key interface{}) (interface{}, bool) {
return val, ok return val, ok
} }
// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
func (m *SafeMap) Range(f func(key, val interface{}) bool) {
m.lock.RLock()
defer m.lock.RUnlock()
for k, v := range m.dirtyOld {
if !f(k, v) {
return
}
}
for k, v := range m.dirtyNew {
if !f(k, v) {
return
}
}
}
// Set sets the value into m with the given key. // Set sets the value into m with the given key.
func (m *SafeMap) Set(key, value interface{}) { func (m *SafeMap) Set(key, value interface{}) {
m.lock.Lock() m.lock.Lock()

View File

@@ -1,11 +1,10 @@
package collection package collection
import ( import (
"sync/atomic"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
) )
func TestSafeMap(t *testing.T) { func TestSafeMap(t *testing.T) {
@@ -108,42 +107,3 @@ func testSafeMapWithParameters(t *testing.T, size, exception int) {
} }
} }
} }
func TestSafeMap_Range(t *testing.T) {
const (
size = 100000
exception1 = 5
exception2 = 500
)
m := NewSafeMap()
newMap := NewSafeMap()
for i := 0; i < size; i++ {
m.Set(i, i)
}
for i := 0; i < size; i++ {
if i%exception1 == 0 {
m.Del(i)
}
}
for i := size; i < size<<1; i++ {
m.Set(i, i)
}
for i := size; i < size<<1; i++ {
if i%exception2 != 0 {
m.Del(i)
}
}
var count int32
m.Range(func(k, v interface{}) bool {
atomic.AddInt32(&count, 1)
newMap.Set(k, v)
return true
})
assert.Equal(t, int(atomic.LoadInt32(&count)), m.Size())
assert.Equal(t, m.dirtyNew, newMap.dirtyNew)
assert.Equal(t, m.dirtyOld, newMap.dirtyOld)
}

View File

@@ -1,8 +1,8 @@
package collection package collection
import ( import (
"github.com/zeromicro/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
) )
const ( const (
@@ -29,7 +29,7 @@ func NewSet() *Set {
} }
} }
// NewUnmanagedSet returns an unmanaged Set, which can put values with different types. // NewUnmanagedSet returns a unmanaged Set, which can put values with different types.
func NewUnmanagedSet() *Set { func NewUnmanagedSet() *Set {
return &Set{ return &Set{
data: make(map[interface{}]lang.PlaceholderType), data: make(map[interface{}]lang.PlaceholderType),
@@ -213,23 +213,23 @@ func (s *Set) validate(i interface{}) {
switch i.(type) { switch i.(type) {
case int: case int:
if s.tp != intType { if s.tp != intType {
logx.Errorf("element is int, but set contains elements with type %d", s.tp) logx.Errorf("Error: element is int, but set contains elements with type %d", s.tp)
} }
case int64: case int64:
if s.tp != int64Type { if s.tp != int64Type {
logx.Errorf("element is int64, but set contains elements with type %d", s.tp) logx.Errorf("Error: element is int64, but set contains elements with type %d", s.tp)
} }
case uint: case uint:
if s.tp != uintType { if s.tp != uintType {
logx.Errorf("element is uint, but set contains elements with type %d", s.tp) logx.Errorf("Error: element is uint, but set contains elements with type %d", s.tp)
} }
case uint64: case uint64:
if s.tp != uint64Type { if s.tp != uint64Type {
logx.Errorf("element is uint64, but set contains elements with type %d", s.tp) logx.Errorf("Error: element is uint64, but set contains elements with type %d", s.tp)
} }
case string: case string:
if s.tp != stringType { if s.tp != stringType {
logx.Errorf("element is string, but set contains elements with type %d", s.tp) logx.Errorf("Error: element is string, but set contains elements with type %d", s.tp)
} }
} }
} }

View File

@@ -5,7 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
) )
func init() { func init() {

View File

@@ -2,22 +2,16 @@ package collection
import ( import (
"container/list" "container/list"
"errors"
"fmt" "fmt"
"time" "time"
"github.com/zeromicro/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/threading" "github.com/tal-tech/go-zero/core/threading"
"github.com/zeromicro/go-zero/core/timex" "github.com/tal-tech/go-zero/core/timex"
) )
const drainWorkers = 8 const drainWorkers = 8
var (
ErrClosed = errors.New("TimingWheel is closed already")
ErrArgument = errors.New("incorrect task argument")
)
type ( type (
// Execute defines the method to execute the task. // Execute defines the method to execute the task.
Execute func(key, value interface{}) Execute func(key, value interface{})
@@ -65,16 +59,14 @@ type (
// NewTimingWheel returns a TimingWheel. // NewTimingWheel returns a TimingWheel.
func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) { func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) {
if interval <= 0 || numSlots <= 0 || execute == nil { if interval <= 0 || numSlots <= 0 || execute == nil {
return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p", return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p", interval, numSlots, execute)
interval, numSlots, execute)
} }
return NewTimingWheelWithTicker(interval, numSlots, execute, timex.NewTicker(interval)) return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval))
} }
// NewTimingWheelWithTicker returns a TimingWheel with the given ticker. func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute, ticker timex.Ticker) (
func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Execute, *TimingWheel, error) {
ticker timex.Ticker) (*TimingWheel, error) {
tw := &TimingWheel{ tw := &TimingWheel{
interval: interval, interval: interval,
ticker: ticker, ticker: ticker,
@@ -97,67 +89,47 @@ func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Exec
} }
// Drain drains all items and executes them. // Drain drains all items and executes them.
func (tw *TimingWheel) Drain(fn func(key, value interface{})) error { func (tw *TimingWheel) Drain(fn func(key, value interface{})) {
select { tw.drainChannel <- fn
case tw.drainChannel <- fn:
return nil
case <-tw.stopChannel:
return ErrClosed
}
} }
// MoveTimer moves the task with the given key to the given delay. // MoveTimer moves the task with the given key to the given delay.
func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) error { func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) {
if delay <= 0 || key == nil { if delay <= 0 || key == nil {
return ErrArgument return
} }
select { tw.moveChannel <- baseEntry{
case tw.moveChannel <- baseEntry{
delay: delay, delay: delay,
key: key, key: key,
}:
return nil
case <-tw.stopChannel:
return ErrClosed
} }
} }
// RemoveTimer removes the task with the given key. // RemoveTimer removes the task with the given key.
func (tw *TimingWheel) RemoveTimer(key interface{}) error { func (tw *TimingWheel) RemoveTimer(key interface{}) {
if key == nil { if key == nil {
return ErrArgument return
} }
select { tw.removeChannel <- key
case tw.removeChannel <- key:
return nil
case <-tw.stopChannel:
return ErrClosed
}
} }
// SetTimer sets the task value with the given key to the delay. // SetTimer sets the task value with the given key to the delay.
func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) error { func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) {
if delay <= 0 || key == nil { if delay <= 0 || key == nil {
return ErrArgument return
} }
select { tw.setChannel <- timingEntry{
case tw.setChannel <- timingEntry{
baseEntry: baseEntry{ baseEntry: baseEntry{
delay: delay, delay: delay,
key: key, key: key,
}, },
value: value, value: value,
}:
return nil
case <-tw.stopChannel:
return ErrClosed
} }
} }
// Stop stops tw. No more actions after stopping a TimingWheel. // Stop stops tw.
func (tw *TimingWheel) Stop() { func (tw *TimingWheel) Stop() {
close(tw.stopChannel) close(tw.stopChannel)
} }

View File

@@ -8,10 +8,10 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
"github.com/zeromicro/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/timex" "github.com/tal-tech/go-zero/core/timex"
) )
const ( const (
@@ -26,8 +26,9 @@ func TestNewTimingWheel(t *testing.T) {
func TestTimingWheel_Drain(t *testing.T) { func TestTimingWheel_Drain(t *testing.T) {
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) { tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
}, ticker) }, ticker)
defer tw.Stop()
tw.SetTimer("first", 3, testStep*4) tw.SetTimer("first", 3, testStep*4)
tw.SetTimer("second", 5, testStep*7) tw.SetTimer("second", 5, testStep*7)
tw.SetTimer("third", 7, testStep*7) tw.SetTimer("third", 7, testStep*7)
@@ -55,14 +56,12 @@ func TestTimingWheel_Drain(t *testing.T) {
}) })
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
assert.Equal(t, 0, count) assert.Equal(t, 0, count)
tw.Stop()
assert.Equal(t, ErrClosed, tw.Drain(func(key, value interface{}) {}))
} }
func TestTimingWheel_SetTimerSoon(t *testing.T) { func TestTimingWheel_SetTimerSoon(t *testing.T) {
run := syncx.NewAtomicBool() run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) { tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true)) assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k) assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int)) assert.Equal(t, 3, v.(int))
@@ -78,7 +77,7 @@ func TestTimingWheel_SetTimerSoon(t *testing.T) {
func TestTimingWheel_SetTimerTwice(t *testing.T) { func TestTimingWheel_SetTimerTwice(t *testing.T) {
run := syncx.NewAtomicBool() run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) { tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true)) assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k) assert.Equal(t, "any", k)
assert.Equal(t, 5, v.(int)) assert.Equal(t, 5, v.(int))
@@ -96,29 +95,23 @@ func TestTimingWheel_SetTimerTwice(t *testing.T) {
func TestTimingWheel_SetTimerWrongDelay(t *testing.T) { func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker) tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
defer tw.Stop() defer tw.Stop()
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
tw.SetTimer("any", 3, -testStep) tw.SetTimer("any", 3, -testStep)
}) })
} }
func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
tw.Stop()
assert.Equal(t, ErrClosed, tw.SetTimer("any", 3, testStep))
}
func TestTimingWheel_MoveTimer(t *testing.T) { func TestTimingWheel_MoveTimer(t *testing.T) {
run := syncx.NewAtomicBool() run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) { tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true)) assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k) assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int)) assert.Equal(t, 3, v.(int))
ticker.Done() ticker.Done()
}, ticker) }, ticker)
defer tw.Stop()
tw.SetTimer("any", 3, testStep*4) tw.SetTimer("any", 3, testStep*4)
tw.MoveTimer("any", testStep*7) tw.MoveTimer("any", testStep*7)
tw.MoveTimer("any", -testStep) tw.MoveTimer("any", -testStep)
@@ -132,14 +125,12 @@ func TestTimingWheel_MoveTimer(t *testing.T) {
} }
assert.Nil(t, ticker.Wait(waitTime)) assert.Nil(t, ticker.Wait(waitTime))
assert.True(t, run.True()) assert.True(t, run.True())
tw.Stop()
assert.Equal(t, ErrClosed, tw.MoveTimer("any", time.Millisecond))
} }
func TestTimingWheel_MoveTimerSoon(t *testing.T) { func TestTimingWheel_MoveTimerSoon(t *testing.T) {
run := syncx.NewAtomicBool() run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) { tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true)) assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k) assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int)) assert.Equal(t, 3, v.(int))
@@ -155,7 +146,7 @@ func TestTimingWheel_MoveTimerSoon(t *testing.T) {
func TestTimingWheel_MoveTimerEarlier(t *testing.T) { func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
run := syncx.NewAtomicBool() run := syncx.NewAtomicBool()
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) { tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
assert.True(t, run.CompareAndSwap(false, true)) assert.True(t, run.CompareAndSwap(false, true))
assert.Equal(t, "any", k) assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int)) assert.Equal(t, 3, v.(int))
@@ -173,7 +164,7 @@ func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
func TestTimingWheel_RemoveTimer(t *testing.T) { func TestTimingWheel_RemoveTimer(t *testing.T) {
ticker := timex.NewFakeTicker() ticker := timex.NewFakeTicker()
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker) tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
tw.SetTimer("any", 3, testStep) tw.SetTimer("any", 3, testStep)
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
tw.RemoveTimer("any") tw.RemoveTimer("any")
@@ -184,7 +175,6 @@ func TestTimingWheel_RemoveTimer(t *testing.T) {
ticker.Tick() ticker.Tick()
} }
tw.Stop() tw.Stop()
assert.Equal(t, ErrClosed, tw.RemoveTimer("any"))
} }
func TestTimingWheel_SetTimer(t *testing.T) { func TestTimingWheel_SetTimer(t *testing.T) {
@@ -236,7 +226,7 @@ func TestTimingWheel_SetTimer(t *testing.T) {
} }
var actual int32 var actual int32
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) { tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
assert.Equal(t, 1, key.(int)) assert.Equal(t, 1, key.(int))
assert.Equal(t, 2, value.(int)) assert.Equal(t, 2, value.(int))
actual = atomic.LoadInt32(&count) actual = atomic.LoadInt32(&count)
@@ -317,7 +307,7 @@ func TestTimingWheel_SetAndMoveThenStart(t *testing.T) {
} }
var actual int32 var actual int32
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) { tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count) actual = atomic.LoadInt32(&count)
close(done) close(done)
}, ticker) }, ticker)
@@ -405,7 +395,7 @@ func TestTimingWheel_SetAndMoveTwice(t *testing.T) {
} }
var actual int32 var actual int32
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) { tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count) actual = atomic.LoadInt32(&count)
close(done) close(done)
}, ticker) }, ticker)
@@ -486,7 +476,7 @@ func TestTimingWheel_ElapsedAndSet(t *testing.T) {
} }
var actual int32 var actual int32
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) { tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count) actual = atomic.LoadInt32(&count)
close(done) close(done)
}, ticker) }, ticker)
@@ -577,7 +567,7 @@ func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
} }
var actual int32 var actual int32
done := make(chan lang.PlaceholderType) done := make(chan lang.PlaceholderType)
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) { tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
actual = atomic.LoadInt32(&count) actual = atomic.LoadInt32(&count)
close(done) close(done)
}, ticker) }, ticker)
@@ -612,7 +602,7 @@ func TestMoveAndRemoveTask(t *testing.T) {
} }
} }
var keys []int var keys []int
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) { tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
assert.Equal(t, "any", k) assert.Equal(t, "any", k)
assert.Equal(t, 3, v.(int)) assert.Equal(t, 3, v.(int))
keys = append(keys, v.(int)) keys = append(keys, v.(int))

View File

@@ -1,73 +0,0 @@
package color
import "github.com/fatih/color"
const (
// NoColor is no color for both foreground and background.
NoColor Color = iota
// FgBlack is the foreground color black.
FgBlack
// FgRed is the foreground color red.
FgRed
// FgGreen is the foreground color green.
FgGreen
// FgYellow is the foreground color yellow.
FgYellow
// FgBlue is the foreground color blue.
FgBlue
// FgMagenta is the foreground color magenta.
FgMagenta
// FgCyan is the foreground color cyan.
FgCyan
// FgWhite is the foreground color white.
FgWhite
// BgBlack is the background color black.
BgBlack
// BgRed is the background color red.
BgRed
// BgGreen is the background color green.
BgGreen
// BgYellow is the background color yellow.
BgYellow
// BgBlue is the background color blue.
BgBlue
// BgMagenta is the background color magenta.
BgMagenta
// BgCyan is the background color cyan.
BgCyan
// BgWhite is the background color white.
BgWhite
)
var colors = map[Color][]color.Attribute{
FgBlack: {color.FgBlack, color.Bold},
FgRed: {color.FgRed, color.Bold},
FgGreen: {color.FgGreen, color.Bold},
FgYellow: {color.FgYellow, color.Bold},
FgBlue: {color.FgBlue, color.Bold},
FgMagenta: {color.FgMagenta, color.Bold},
FgCyan: {color.FgCyan, color.Bold},
FgWhite: {color.FgWhite, color.Bold},
BgBlack: {color.BgBlack, color.FgHiWhite, color.Bold},
BgRed: {color.BgRed, color.FgHiWhite, color.Bold},
BgGreen: {color.BgGreen, color.FgHiWhite, color.Bold},
BgYellow: {color.BgHiYellow, color.FgHiBlack, color.Bold},
BgBlue: {color.BgBlue, color.FgHiWhite, color.Bold},
BgMagenta: {color.BgMagenta, color.FgHiWhite, color.Bold},
BgCyan: {color.BgCyan, color.FgHiWhite, color.Bold},
BgWhite: {color.BgHiWhite, color.FgHiBlack, color.Bold},
}
type Color uint32
// WithColor returns a string with the given color applied.
func WithColor(text string, colour Color) string {
c := color.New(colors[colour]...)
return c.Sprint(text)
}
// WithColorPadding returns a string with the given color applied with leading and trailing spaces.
func WithColorPadding(text string, colour Color) string {
return WithColor(" "+text+" ", colour)
}

View File

@@ -1,17 +0,0 @@
package color
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestWithColor(t *testing.T) {
output := WithColor("Hello", BgRed)
assert.Equal(t, "Hello", output)
}
func TestWithColorPadding(t *testing.T) {
output := WithColorPadding("Hello", BgRed)
assert.Equal(t, " Hello ", output)
}

View File

@@ -2,50 +2,28 @@ package conf
import ( import (
"fmt" "fmt"
"io/ioutil"
"log" "log"
"os" "os"
"path" "path"
"reflect"
"strings"
"github.com/zeromicro/go-zero/core/jsonx" "github.com/tal-tech/go-zero/core/mapping"
"github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/internal/encoding"
) )
const jsonTagKey = "json" var loaders = map[string]func([]byte, interface{}) error{
".json": LoadConfigFromJsonBytes,
var ( ".yaml": LoadConfigFromYamlBytes,
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault()) ".yml": LoadConfigFromYamlBytes,
loaders = map[string]func([]byte, interface{}) error{
".json": LoadFromJsonBytes,
".toml": LoadFromTomlBytes,
".yaml": LoadFromYamlBytes,
".yml": LoadFromYamlBytes,
}
)
// children and mapField should not be both filled.
// named fields and map cannot be bound to the same field name.
type fieldInfo struct {
children map[string]*fieldInfo
mapField *fieldInfo
} }
// FillDefault fills the default values for the given v, // LoadConfig loads config into v from file, .json, .yaml and .yml are acceptable.
// and the premise is that the value of v must be guaranteed to be empty. func LoadConfig(file string, v interface{}, opts ...Option) error {
func FillDefault(v interface{}) error { content, err := ioutil.ReadFile(file)
return fillDefaultUnmarshaler.Unmarshal(map[string]interface{}{}, v)
}
// Load loads config into v from file, .json, .yaml and .yml are acceptable.
func Load(file string, v interface{}, opts ...Option) error {
content, err := os.ReadFile(file)
if err != nil { if err != nil {
return err return err
} }
loader, ok := loaders[strings.ToLower(path.Ext(file))] loader, ok := loaders[path.Ext(file)]
if !ok { if !ok {
return fmt.Errorf("unrecognized file type: %s", file) return fmt.Errorf("unrecognized file type: %s", file)
} }
@@ -62,266 +40,19 @@ func Load(file string, v interface{}, opts ...Option) error {
return loader(content, v) return loader(content, v)
} }
// LoadConfig loads config into v from file, .json, .yaml and .yml are acceptable.
// Deprecated: use Load instead.
func LoadConfig(file string, v interface{}, opts ...Option) error {
return Load(file, v, opts...)
}
// LoadFromJsonBytes loads config into v from content json bytes.
func LoadFromJsonBytes(content []byte, v interface{}) error {
info, err := buildFieldsInfo(reflect.TypeOf(v))
if err != nil {
return err
}
var m map[string]interface{}
if err := jsonx.Unmarshal(content, &m); err != nil {
return err
}
lowerCaseKeyMap := toLowerCaseKeyMap(m, info)
return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
}
// LoadConfigFromJsonBytes loads config into v from content json bytes. // LoadConfigFromJsonBytes loads config into v from content json bytes.
// Deprecated: use LoadFromJsonBytes instead.
func LoadConfigFromJsonBytes(content []byte, v interface{}) error { func LoadConfigFromJsonBytes(content []byte, v interface{}) error {
return LoadFromJsonBytes(content, v) return mapping.UnmarshalJsonBytes(content, v)
}
// LoadFromTomlBytes loads config into v from content toml bytes.
func LoadFromTomlBytes(content []byte, v interface{}) error {
b, err := encoding.TomlToJson(content)
if err != nil {
return err
}
return LoadFromJsonBytes(b, v)
}
// LoadFromYamlBytes loads config into v from content yaml bytes.
func LoadFromYamlBytes(content []byte, v interface{}) error {
b, err := encoding.YamlToJson(content)
if err != nil {
return err
}
return LoadFromJsonBytes(b, v)
} }
// LoadConfigFromYamlBytes loads config into v from content yaml bytes. // LoadConfigFromYamlBytes loads config into v from content yaml bytes.
// Deprecated: use LoadFromYamlBytes instead.
func LoadConfigFromYamlBytes(content []byte, v interface{}) error { func LoadConfigFromYamlBytes(content []byte, v interface{}) error {
return LoadFromYamlBytes(content, v) return mapping.UnmarshalYamlBytes(content, v)
} }
// MustLoad loads config into v from path, exits on error. // MustLoad loads config into v from path, exits on error.
func MustLoad(path string, v interface{}, opts ...Option) { func MustLoad(path string, v interface{}, opts ...Option) {
if err := Load(path, v, opts...); err != nil { if err := LoadConfig(path, v, opts...); err != nil {
log.Fatalf("error: config file %s, %s", path, err.Error()) log.Fatalf("error: config file %s, %s", path, err.Error())
} }
} }
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
if prev, ok := info.children[key]; ok {
if child.mapField != nil {
return newDupKeyError(key)
}
if err := mergeFields(prev, key, child.children); err != nil {
return err
}
} else {
info.children[key] = child
}
return nil
}
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
switch ft.Kind() {
case reflect.Struct:
fields, err := buildFieldsInfo(ft)
if err != nil {
return err
}
for k, v := range fields.children {
if err = addOrMergeFields(info, k, v); err != nil {
return err
}
}
case reflect.Map:
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil {
return err
}
if _, ok := info.children[lowerCaseName]; ok {
return newDupKeyError(lowerCaseName)
}
info.children[lowerCaseName] = &fieldInfo{
children: make(map[string]*fieldInfo),
mapField: elemField,
}
default:
if _, ok := info.children[lowerCaseName]; ok {
return newDupKeyError(lowerCaseName)
}
info.children[lowerCaseName] = &fieldInfo{
children: make(map[string]*fieldInfo),
}
}
return nil
}
func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
tp = mapping.Deref(tp)
switch tp.Kind() {
case reflect.Struct:
return buildStructFieldsInfo(tp)
case reflect.Array, reflect.Slice:
return buildFieldsInfo(mapping.Deref(tp.Elem()))
case reflect.Chan, reflect.Func:
return nil, fmt.Errorf("unsupported type: %s", tp.Kind())
default:
return &fieldInfo{
children: make(map[string]*fieldInfo),
}, nil
}
}
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
var finfo *fieldInfo
var err error
switch ft.Kind() {
case reflect.Struct:
finfo, err = buildFieldsInfo(ft)
if err != nil {
return err
}
case reflect.Array, reflect.Slice:
finfo, err = buildFieldsInfo(ft.Elem())
if err != nil {
return err
}
case reflect.Map:
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
if err != nil {
return err
}
finfo = &fieldInfo{
children: make(map[string]*fieldInfo),
mapField: elemInfo,
}
default:
finfo, err = buildFieldsInfo(ft)
if err != nil {
return err
}
}
return addOrMergeFields(info, lowerCaseName, finfo)
}
func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
info := &fieldInfo{
children: make(map[string]*fieldInfo),
}
for i := 0; i < tp.NumField(); i++ {
field := tp.Field(i)
name := field.Name
lowerCaseName := toLowerCase(name)
ft := mapping.Deref(field.Type)
// flatten anonymous fields
if field.Anonymous {
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil {
return nil, err
}
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil {
return nil, err
}
}
return info, nil
}
func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error {
if len(prev.children) == 0 || len(children) == 0 {
return newDupKeyError(key)
}
// merge fields
for k, v := range children {
if _, ok := prev.children[k]; ok {
return newDupKeyError(k)
}
prev.children[k] = v
}
return nil
}
func toLowerCase(s string) string {
return strings.ToLower(s)
}
func toLowerCaseInterface(v interface{}, info *fieldInfo) interface{} {
switch vv := v.(type) {
case map[string]interface{}:
return toLowerCaseKeyMap(vv, info)
case []interface{}:
var arr []interface{}
for _, vvv := range vv {
arr = append(arr, toLowerCaseInterface(vvv, info))
}
return arr
default:
return v
}
}
func toLowerCaseKeyMap(m map[string]interface{}, info *fieldInfo) map[string]interface{} {
res := make(map[string]interface{})
for k, v := range m {
ti, ok := info.children[k]
if ok {
res[k] = toLowerCaseInterface(v, ti)
continue
}
lk := toLowerCase(k)
if ti, ok = info.children[lk]; ok {
res[lk] = toLowerCaseInterface(v, ti)
} else if info.mapField != nil {
res[k] = toLowerCaseInterface(v, info.mapField)
} else {
res[k] = v
}
}
return res
}
type dupKeyError struct {
key string
}
func newDupKeyError(key string) dupKeyError {
return dupKeyError{key: key}
}
func (e dupKeyError) Error() string {
return fmt.Sprintf("duplicated key %s", e.key)
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,11 +7,12 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/zeromicro/go-zero/core/iox" "github.com/tal-tech/go-zero/core/iox"
) )
// PropertyError represents a configuration error message. // PropertyError represents a configuration error message.
type PropertyError struct { type PropertyError struct {
error
message string message string
} }

View File

@@ -5,7 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/fs" "github.com/tal-tech/go-zero/core/fs"
) )
func TestProperties(t *testing.T) { func TestProperties(t *testing.T) {

View File

@@ -1,58 +0,0 @@
## How to use
1. Define a config structure, like below:
```go
type RestfulConf struct {
ServiceName string `json:",env=SERVICE_NAME"` // read from env automatically
Host string `json:",default=0.0.0.0"`
Port int
LogMode string `json:",options=[file,console]"`
Verbose bool `json:",optional"`
MaxConns int `json:",default=10000"`
MaxBytes int64 `json:",default=1048576"`
Timeout time.Duration `json:",default=3s"`
CpuThreshold int64 `json:",default=900,range=[0:1000]"`
}
```
2. Write the yaml, toml or json config file:
- yaml example
```yaml
# most fields are optional or have default values
port: 8080
logMode: console
# you can use env settings
maxBytes: ${MAX_BYTES}
```
- toml example
```toml
# most fields are optional or have default values
port = 8_080
logMode = "console"
# you can use env settings
maxBytes = "${MAX_BYTES}"
```
3. Load the config from a file:
```go
// exit on error
var config RestfulConf
conf.MustLoad(configFile, &config)
// or handle the error on your own
var config RestfulConf
if err := conf.Load(configFile, &config); err != nil {
log.Fatal(err)
}
// enable reading from environments
var config RestfulConf
conf.MustLoad(configFile, &config, conf.UseEnv())
```

View File

@@ -3,7 +3,7 @@ package contextx
import ( import (
"context" "context"
"github.com/zeromicro/go-zero/core/mapping" "github.com/tal-tech/go-zero/core/mapping"
) )
const contextTagKey = "ctx" const contextTagKey = "ctx"

View File

@@ -1,6 +1,6 @@
package discov package discov
import "github.com/zeromicro/go-zero/core/discov/internal" import "github.com/tal-tech/go-zero/core/discov/internal"
// RegisterAccount registers the username/password to the given etcd cluster. // RegisterAccount registers the username/password to the given etcd cluster.
func RegisterAccount(endpoints []string, user, pass string) { func RegisterAccount(endpoints []string, user, pass string) {

View File

@@ -4,8 +4,8 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/discov/internal" "github.com/tal-tech/go-zero/core/discov/internal"
"github.com/zeromicro/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
) )
func TestRegisterAccount(t *testing.T) { func TestRegisterAccount(t *testing.T) {

View File

@@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/zeromicro/go-zero/core/discov/internal" "github.com/tal-tech/go-zero/core/discov/internal"
) )
const ( const (

View File

@@ -5,7 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/discov/internal" "github.com/tal-tech/go-zero/core/discov/internal"
) )
var mockLock sync.Mutex var mockLock sync.Mutex

View File

@@ -2,13 +2,6 @@ package discov
import "errors" import "errors"
var (
// errEmptyEtcdHosts indicates that etcd hosts are empty.
errEmptyEtcdHosts = errors.New("empty etcd hosts")
// errEmptyEtcdKey indicates that etcd key is empty.
errEmptyEtcdKey = errors.New("empty etcd key")
)
// EtcdConf is the config item with the given key on etcd. // EtcdConf is the config item with the given key on etcd.
type EtcdConf struct { type EtcdConf struct {
Hosts []string Hosts []string
@@ -34,9 +27,9 @@ func (c EtcdConf) HasTLS() bool {
// Validate validates c. // Validate validates c.
func (c EtcdConf) Validate() error { func (c EtcdConf) Validate() error {
if len(c.Hosts) == 0 { if len(c.Hosts) == 0 {
return errEmptyEtcdHosts return errors.New("empty etcd hosts")
} else if len(c.Key) == 0 { } else if len(c.Key) == 0 {
return errEmptyEtcdKey return errors.New("empty etcd key")
} else { } else {
return nil return nil
} }

View File

@@ -3,7 +3,7 @@ package internal
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"os" "io/ioutil"
"sync" "sync"
) )
@@ -37,7 +37,7 @@ func AddTLS(endpoints []string, certFile, certKeyFile, caFile string, insecureSk
return err return err
} }
caData, err := os.ReadFile(caFile) caData, err := ioutil.ReadFile(caFile)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -4,7 +4,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
) )
func TestAccount(t *testing.T) { func TestAccount(t *testing.T) {

View File

@@ -9,11 +9,11 @@ import (
"sync" "sync"
"time" "time"
"github.com/zeromicro/go-zero/core/contextx" "github.com/tal-tech/go-zero/core/contextx"
"github.com/zeromicro/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/threading" "github.com/tal-tech/go-zero/core/threading"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
) )
@@ -191,11 +191,9 @@ func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
}) })
} }
case clientv3.EventTypeDelete: case clientv3.EventTypeDelete:
c.lock.Lock()
if vals, ok := c.values[key]; ok { if vals, ok := c.values[key]; ok {
delete(vals, string(ev.Kv.Key)) delete(vals, string(ev.Kv.Key))
} }
c.lock.Unlock()
for _, l := range listeners { for _, l := range listeners {
l.OnDelete(KV{ l.OnDelete(KV{
Key: string(ev.Kv.Key), Key: string(ev.Kv.Key),
@@ -208,7 +206,7 @@ func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
} }
} }
func (c *cluster) load(cli EtcdClient, key string) int64 { func (c *cluster) load(cli EtcdClient, key string) {
var resp *clientv3.GetResponse var resp *clientv3.GetResponse
for { for {
var err error var err error
@@ -232,8 +230,6 @@ func (c *cluster) load(cli EtcdClient, key string) int64 {
} }
c.handleChanges(key, kvs) c.handleChanges(key, kvs)
return resp.Header.Revision
} }
func (c *cluster) monitor(key string, l UpdateListener) error { func (c *cluster) monitor(key string, l UpdateListener) error {
@@ -246,9 +242,9 @@ func (c *cluster) monitor(key string, l UpdateListener) error {
return err return err
} }
rev := c.load(cli, key) c.load(cli, key)
c.watchGroup.Run(func() { c.watchGroup.Run(func() {
c.watch(cli, key, rev) c.watch(cli, key)
}) })
return nil return nil
@@ -280,29 +276,22 @@ func (c *cluster) reload(cli EtcdClient) {
for _, key := range keys { for _, key := range keys {
k := key k := key
c.watchGroup.Run(func() { c.watchGroup.Run(func() {
rev := c.load(cli, k) c.load(cli, k)
c.watch(cli, k, rev) c.watch(cli, k)
}) })
} }
} }
func (c *cluster) watch(cli EtcdClient, key string, rev int64) { func (c *cluster) watch(cli EtcdClient, key string) {
for { for {
if c.watchStream(cli, key, rev) { if c.watchStream(cli, key) {
return return
} }
} }
} }
func (c *cluster) watchStream(cli EtcdClient, key string, rev int64) bool { func (c *cluster) watchStream(cli EtcdClient, key string) bool {
var rch clientv3.WatchChan rch := cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix())
if rev != 0 {
rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix(),
clientv3.WithRev(rev+1))
} else {
rch = cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix())
}
for { for {
select { select {
case wresp, ok := <-rch: case wresp, ok := <-rch:
@@ -343,7 +332,6 @@ func DialClient(endpoints []string) (EtcdClient, error) {
DialKeepAliveTime: dialKeepAliveTime, DialKeepAliveTime: dialKeepAliveTime,
DialKeepAliveTimeout: DialTimeout, DialKeepAliveTimeout: DialTimeout,
RejectOldCluster: true, RejectOldCluster: true,
PermitWithoutStream: true,
} }
if account, ok := GetAccount(endpoints); ok { if account, ok := GetAccount(endpoints); ok {
cfg.Username = account.User cfg.Username = account.User

View File

@@ -7,11 +7,10 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/contextx" "github.com/tal-tech/go-zero/core/contextx"
"github.com/zeromicro/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
"go.etcd.io/etcd/api/v3/etcdserverpb"
"go.etcd.io/etcd/api/v3/mvccpb" "go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
) )
@@ -113,7 +112,6 @@ func TestCluster_Load(t *testing.T) {
restore := setMockClient(cli) restore := setMockClient(cli)
defer restore() defer restore()
cli.EXPECT().Get(gomock.Any(), "any/", gomock.Any()).Return(&clientv3.GetResponse{ cli.EXPECT().Get(gomock.Any(), "any/", gomock.Any()).Return(&clientv3.GetResponse{
Header: &etcdserverpb.ResponseHeader{},
Kvs: []*mvccpb.KeyValue{ Kvs: []*mvccpb.KeyValue{
{ {
Key: []byte("hello"), Key: []byte("hello"),
@@ -170,7 +168,7 @@ func TestCluster_Watch(t *testing.T) {
listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ interface{}) { listener.EXPECT().OnDelete(gomock.Any()).Do(func(_ interface{}) {
wg.Done() wg.Done()
}).MaxTimes(1) }).MaxTimes(1)
go c.watch(cli, "any", 0) go c.watch(cli, "any")
ch <- clientv3.WatchResponse{ ch <- clientv3.WatchResponse{
Events: []*clientv3.Event{ Events: []*clientv3.Event{
{ {
@@ -214,7 +212,7 @@ func TestClusterWatch_RespFailures(t *testing.T) {
ch <- resp ch <- resp
close(c.done) close(c.done)
}() }()
c.watch(cli, "any", 0) c.watch(cli, "any")
}) })
} }
} }
@@ -234,7 +232,7 @@ func TestClusterWatch_CloseChan(t *testing.T) {
close(ch) close(ch)
close(c.done) close(c.done)
}() }()
c.watch(cli, "any", 0) c.watch(cli, "any")
} }
func TestValueOnlyContext(t *testing.T) { func TestValueOnlyContext(t *testing.T) {

View File

@@ -1,14 +1,12 @@
package discov package discov
import ( import (
"time" "github.com/tal-tech/go-zero/core/discov/internal"
"github.com/tal-tech/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/discov/internal" "github.com/tal-tech/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/lang" "github.com/tal-tech/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/proc" "github.com/tal-tech/go-zero/core/threading"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/threading"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
) )
@@ -53,7 +51,12 @@ func NewPublisher(endpoints []string, key, value string, opts ...PubOption) *Pub
// KeepAlive keeps key:value alive. // KeepAlive keeps key:value alive.
func (p *Publisher) KeepAlive() error { func (p *Publisher) KeepAlive() error {
cli, err := p.doRegister() cli, err := internal.GetRegistry().GetConn(p.endpoints)
if err != nil {
return err
}
p.lease, err = p.register(cli)
if err != nil { if err != nil {
return err return err
} }
@@ -80,43 +83,6 @@ func (p *Publisher) Stop() {
p.quit.Close() p.quit.Close()
} }
func (p *Publisher) doKeepAlive() error {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for range ticker.C {
select {
case <-p.quit.Done():
return nil
default:
cli, err := p.doRegister()
if err != nil {
logx.Errorf("etcd publisher doRegister: %s", err.Error())
break
}
if err := p.keepAliveAsync(cli); err != nil {
logx.Errorf("etcd publisher keepAliveAsync: %s", err.Error())
break
}
return nil
}
}
return nil
}
func (p *Publisher) doRegister() (internal.EtcdClient, error) {
cli, err := internal.GetRegistry().GetConn(p.endpoints)
if err != nil {
return nil, err
}
p.lease, err = p.register(cli)
return cli, err
}
func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error { func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
ch, err := cli.KeepAlive(cli.Ctx(), p.lease) ch, err := cli.KeepAlive(cli.Ctx(), p.lease)
if err != nil { if err != nil {
@@ -129,8 +95,8 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
case _, ok := <-ch: case _, ok := <-ch:
if !ok { if !ok {
p.revoke(cli) p.revoke(cli)
if err := p.doKeepAlive(); err != nil { if err := p.KeepAlive(); err != nil {
logx.Errorf("etcd publisher KeepAlive: %s", err.Error()) logx.Errorf("KeepAlive: %s", err.Error())
} }
return return
} }
@@ -139,8 +105,8 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
p.revoke(cli) p.revoke(cli)
select { select {
case <-p.resumeChan: case <-p.resumeChan:
if err := p.doKeepAlive(); err != nil { if err := p.KeepAlive(); err != nil {
logx.Errorf("etcd publisher KeepAlive: %s", err.Error()) logx.Errorf("KeepAlive: %s", err.Error())
} }
return return
case <-p.quit.Done(): case <-p.quit.Done():
@@ -175,7 +141,7 @@ func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, erro
func (p *Publisher) revoke(cli internal.EtcdClient) { func (p *Publisher) revoke(cli internal.EtcdClient) {
if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil { if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil {
logx.Errorf("etcd publisher revoke: %s", err.Error()) logx.Error(err)
} }
} }

View File

@@ -8,10 +8,10 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/discov/internal" "github.com/tal-tech/go-zero/core/discov/internal"
"github.com/zeromicro/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
) )

View File

@@ -4,16 +4,16 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"github.com/zeromicro/go-zero/core/discov/internal" "github.com/tal-tech/go-zero/core/discov/internal"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
) )
type ( type (
// SubOption defines the method to customize a Subscriber. // SubOption defines the method to customize a Subscriber.
SubOption func(sub *Subscriber) SubOption func(sub *Subscriber)
// A Subscriber is used to subscribe the given key on an etcd cluster. // A Subscriber is used to subscribe the given key on a etcd cluster.
Subscriber struct { Subscriber struct {
endpoints []string endpoints []string
exclusive bool exclusive bool

View File

@@ -5,8 +5,8 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/discov/internal" "github.com/tal-tech/go-zero/core/discov/internal"
"github.com/zeromicro/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
) )
const ( const (

View File

@@ -11,12 +11,10 @@ type (
errorArray []error errorArray []error
) )
// Add adds errs to be, nil errors are ignored. // Add adds err to be.
func (be *BatchError) Add(errs ...error) { func (be *BatchError) Add(err error) {
for _, err := range errs { if err != nil {
if err != nil { be.errs = append(be.errs, err)
be.errs = append(be.errs, err)
}
} }
} }

View File

@@ -1,21 +0,0 @@
package errorx
import "fmt"
// Wrap returns an error that wraps err with given message.
func Wrap(err error, message string) error {
if err == nil {
return nil
}
return fmt.Errorf("%s: %w", message, err)
}
// Wrapf returns an error that wraps err with given format and args.
func Wrapf(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
return fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), err)
}

View File

@@ -1,24 +0,0 @@
package errorx
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestWrap(t *testing.T) {
assert.Nil(t, Wrap(nil, "test"))
assert.Equal(t, "foo: bar", Wrap(errors.New("bar"), "foo").Error())
err := errors.New("foo")
assert.True(t, errors.Is(Wrap(err, "bar"), err))
}
func TestWrapf(t *testing.T) {
assert.Nil(t, Wrapf(nil, "%s", "test"))
assert.Equal(t, "foo bar: quz", Wrapf(errors.New("quz"), "foo %s", "bar").Error())
err := errors.New("foo")
assert.True(t, errors.Is(Wrapf(err, "foo %s", "bar"), err))
}

View File

@@ -53,11 +53,10 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
} }
func TestChunkExecutorEmpty(t *testing.T) { func TestChunkExecutorEmpty(t *testing.T) {
executor := NewChunkExecutor(func(items []interface{}) { NewChunkExecutor(func(items []interface{}) {
assert.Fail(t, "should not called") assert.Fail(t, "should not called")
}, WithChunkBytes(10), WithFlushInterval(time.Millisecond)) }, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
executor.Wait()
} }
func TestChunkExecutorFlush(t *testing.T) { func TestChunkExecutorFlush(t *testing.T) {

View File

@@ -4,7 +4,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/zeromicro/go-zero/core/threading" "github.com/tal-tech/go-zero/core/threading"
) )
// A DelayExecutor delays a tasks on given delay interval. // A DelayExecutor delays a tasks on given delay interval.

View File

@@ -3,8 +3,8 @@ package executors
import ( import (
"time" "time"
"github.com/zeromicro/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/timex" "github.com/tal-tech/go-zero/core/timex"
) )
// A LessExecutor is an executor to limit execution once within given time interval. // A LessExecutor is an executor to limit execution once within given time interval.

View File

@@ -5,7 +5,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/timex" "github.com/tal-tech/go-zero/core/timex"
) )
func TestLessExecutor_DoOrDiscard(t *testing.T) { func TestLessExecutor_DoOrDiscard(t *testing.T) {

View File

@@ -6,11 +6,11 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/zeromicro/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/proc" "github.com/tal-tech/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/threading" "github.com/tal-tech/go-zero/core/threading"
"github.com/zeromicro/go-zero/core/timex" "github.com/tal-tech/go-zero/core/timex"
) )
const idleRound = 10 const idleRound = 10

View File

@@ -8,8 +8,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/proc" "github.com/tal-tech/go-zero/core/timex"
"github.com/zeromicro/go-zero/core/timex"
) )
const threshold = 10 const threshold = 10
@@ -68,7 +67,6 @@ func TestPeriodicalExecutor_QuitGoroutine(t *testing.T) {
ticker.Tick() ticker.Tick()
ticker.Wait(time.Millisecond * idleRound) ticker.Wait(time.Millisecond * idleRound)
assert.Equal(t, routines, runtime.NumGoroutine()) assert.Equal(t, routines, runtime.NumGoroutine())
proc.Shutdown()
} }
func TestPeriodicalExecutor_Bulk(t *testing.T) { func TestPeriodicalExecutor_Bulk(t *testing.T) {

View File

@@ -5,7 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/fs" "github.com/tal-tech/go-zero/core/fs"
) )
const ( const (

View File

@@ -5,7 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/fs" "github.com/tal-tech/go-zero/core/fs"
) )
func TestSplitLineChunks(t *testing.T) { func TestSplitLineChunks(t *testing.T) {

View File

@@ -5,9 +5,6 @@ import (
"os" "os"
) )
// errExceedFileSize indicates that the file size is exceeded.
var errExceedFileSize = errors.New("exceed file size")
// A RangeReader is used to read a range of content from a file. // A RangeReader is used to read a range of content from a file.
type RangeReader struct { type RangeReader struct {
file *os.File file *os.File
@@ -32,7 +29,7 @@ func (rr *RangeReader) Read(p []byte) (n int, err error) {
} }
if rr.stop < rr.start || rr.start >= stat.Size() { if rr.stop < rr.start || rr.start >= stat.Size() {
return 0, errExceedFileSize return 0, errors.New("exceed file size")
} }
if rr.stop-rr.start < int64(len(p)) { if rr.stop-rr.start < int64(len(p)) {

View File

@@ -5,7 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/fs" "github.com/tal-tech/go-zero/core/fs"
) )
func TestRangeReader(t *testing.T) { func TestRangeReader(t *testing.T) {

View File

@@ -1,15 +0,0 @@
package fs
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCloseOnExec(t *testing.T) {
file := os.NewFile(0, os.DevNull)
assert.NotPanics(t, func() {
CloseOnExec(file)
})
}

View File

@@ -1,9 +1,10 @@
package fs package fs
import ( import (
"io/ioutil"
"os" "os"
"github.com/zeromicro/go-zero/core/hash" "github.com/tal-tech/go-zero/core/hash"
) )
// TempFileWithText creates the temporary file with the given content, // TempFileWithText creates the temporary file with the given content,
@@ -11,12 +12,12 @@ import (
// The file is kept as open, the caller should close the file handle, // The file is kept as open, the caller should close the file handle,
// and remove the file by name. // and remove the file by name.
func TempFileWithText(text string) (*os.File, error) { func TempFileWithText(text string) (*os.File, error) {
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))) tmpfile, err := ioutil.TempFile(os.TempDir(), hash.Md5Hex([]byte(text)))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil { if err := ioutil.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
return nil, err return nil, err
} }

View File

@@ -1,49 +0,0 @@
package fs
import (
"io"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
func TestTempFileWithText(t *testing.T) {
f, err := TempFileWithText("test")
if err != nil {
t.Error(err)
}
if f == nil {
t.Error("TempFileWithText returned nil")
}
if f.Name() == "" {
t.Error("TempFileWithText returned empty file name")
}
defer os.Remove(f.Name())
bs, err := io.ReadAll(f)
assert.Nil(t, err)
if len(bs) != 4 {
t.Error("TempFileWithText returned wrong file size")
}
if f.Close() != nil {
t.Error("TempFileWithText returned error on close")
}
}
func TestTempFilenameWithText(t *testing.T) {
f, err := TempFilenameWithText("test")
if err != nil {
t.Error(err)
}
if f == "" {
t.Error("TempFilenameWithText returned empty file name")
}
defer os.Remove(f)
bs, err := os.ReadFile(f)
assert.Nil(t, err)
if len(bs) != 4 {
t.Error("TempFilenameWithText returned wrong file size")
}
}

View File

@@ -1,6 +1,6 @@
package fx package fx
import "github.com/zeromicro/go-zero/core/threading" import "github.com/tal-tech/go-zero/core/threading"
// Parallel runs fns parallelly and waits for done. // Parallel runs fns parallelly and waits for done.
func Parallel(fns ...func()) { func Parallel(fns ...func()) {

View File

@@ -1,6 +1,6 @@
package fx package fx
import "github.com/zeromicro/go-zero/core/errorx" import "github.com/tal-tech/go-zero/core/errorx"
const defaultRetryTimes = 3 const defaultRetryTimes = 3

View File

@@ -4,9 +4,9 @@ import (
"sort" "sort"
"sync" "sync"
"github.com/zeromicro/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/threading" "github.com/tal-tech/go-zero/core/threading"
) )
const ( const (
@@ -328,7 +328,7 @@ func (s Stream) Parallel(fn ParallelFunc, opts ...Option) {
}, opts...).Done() }, opts...).Done()
} }
// Reduce is an utility method to let the caller deal with the underlying channel. // Reduce is a utility method to let the caller deal with the underlying channel.
func (s Stream) Reduce(fn ReduceFunc) (interface{}, error) { func (s Stream) Reduce(fn ReduceFunc) (interface{}, error) {
return fn(s.source) return fn(s.source)
} }

View File

@@ -1,7 +1,7 @@
package fx package fx
import ( import (
"io" "io/ioutil"
"log" "log"
"math/rand" "math/rand"
"reflect" "reflect"
@@ -13,8 +13,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
"go.uber.org/goleak"
) )
func TestBuffer(t *testing.T) { func TestBuffer(t *testing.T) {
@@ -238,7 +237,7 @@ func TestLast(t *testing.T) {
func TestMap(t *testing.T) { func TestMap(t *testing.T) {
runCheckedTest(t, func(t *testing.T) { runCheckedTest(t, func(t *testing.T) {
log.SetOutput(io.Discard) log.SetOutput(ioutil.Discard)
tests := []struct { tests := []struct {
mapper MapFunc mapper MapFunc
@@ -564,6 +563,9 @@ func equal(t *testing.T, stream Stream, data []interface{}) {
} }
func runCheckedTest(t *testing.T, fn func(t *testing.T)) { func runCheckedTest(t *testing.T, fn func(t *testing.T)) {
defer goleak.VerifyNone(t) goroutines := runtime.NumGoroutine()
fn(t) fn(t)
// let scheduler schedule first
time.Sleep(time.Millisecond)
assert.True(t, runtime.NumGoroutine() <= goroutines)
} }

View File

@@ -6,7 +6,8 @@ import (
"strconv" "strconv"
"sync" "sync"
"github.com/zeromicro/go-zero/core/lang" "github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/core/mapping"
) )
const ( const (
@@ -182,5 +183,5 @@ func innerRepr(node interface{}) string {
} }
func repr(node interface{}) string { func repr(node interface{}) string {
return lang.Repr(node) return mapping.Repr(node)
} }

View File

@@ -6,7 +6,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/mathx" "github.com/tal-tech/go-zero/core/mathx"
) )
const ( const (

View File

@@ -10,7 +10,7 @@ func (nopCloser) Close() error {
return nil return nil
} }
// NopCloser returns an io.WriteCloser that does nothing on calling Close. // NopCloser returns a io.WriteCloser that does nothing on calling Close.
func NopCloser(w io.Writer) io.WriteCloser { func NopCloser(w io.Writer) io.WriteCloser {
return nopCloser{w} return nopCloser{w}
} }

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"io" "io"
"io/ioutil"
"os" "os"
"strings" "strings"
) )
@@ -25,7 +26,7 @@ type (
func DupReadCloser(reader io.ReadCloser) (io.ReadCloser, io.ReadCloser) { func DupReadCloser(reader io.ReadCloser) (io.ReadCloser, io.ReadCloser) {
var buf bytes.Buffer var buf bytes.Buffer
tee := io.TeeReader(reader, &buf) tee := io.TeeReader(reader, &buf)
return io.NopCloser(tee), io.NopCloser(&buf) return ioutil.NopCloser(tee), ioutil.NopCloser(&buf)
} }
// KeepSpace customizes the reading functions to keep leading and tailing spaces. // KeepSpace customizes the reading functions to keep leading and tailing spaces.
@@ -53,7 +54,7 @@ func ReadBytes(reader io.Reader, buf []byte) error {
// ReadText reads content from the given file with leading and tailing spaces trimmed. // ReadText reads content from the given file with leading and tailing spaces trimmed.
func ReadText(filename string) (string, error) { func ReadText(filename string) (string, error) {
content, err := os.ReadFile(filename) content, err := ioutil.ReadFile(filename)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -3,13 +3,14 @@ package iox
import ( import (
"bytes" "bytes"
"io" "io"
"io/ioutil"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/fs" "github.com/tal-tech/go-zero/core/fs"
"github.com/zeromicro/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
) )
func TestReadText(t *testing.T) { func TestReadText(t *testing.T) {
@@ -96,10 +97,10 @@ func TestReadTextLines(t *testing.T) {
func TestDupReadCloser(t *testing.T) { func TestDupReadCloser(t *testing.T) {
input := "hello" input := "hello"
reader := io.NopCloser(bytes.NewBufferString(input)) reader := ioutil.NopCloser(bytes.NewBufferString(input))
r1, r2 := DupReadCloser(reader) r1, r2 := DupReadCloser(reader)
verify := func(r io.Reader) { verify := func(r io.Reader) {
output, err := io.ReadAll(r) output, err := ioutil.ReadAll(r)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, input, string(output)) assert.Equal(t, input, string(output))
} }
@@ -109,7 +110,7 @@ func TestDupReadCloser(t *testing.T) {
} }
func TestReadBytes(t *testing.T) { func TestReadBytes(t *testing.T) {
reader := io.NopCloser(bytes.NewBufferString("helloworld")) reader := ioutil.NopCloser(bytes.NewBufferString("helloworld"))
buf := make([]byte, 5) buf := make([]byte, 5)
err := ReadBytes(reader, buf) err := ReadBytes(reader, buf)
assert.Nil(t, err) assert.Nil(t, err)
@@ -117,7 +118,7 @@ func TestReadBytes(t *testing.T) {
} }
func TestReadBytesNotEnough(t *testing.T) { func TestReadBytesNotEnough(t *testing.T) {
reader := io.NopCloser(bytes.NewBufferString("hell")) reader := ioutil.NopCloser(bytes.NewBufferString("hell"))
buf := make([]byte, 5) buf := make([]byte, 5)
err := ReadBytes(reader, buf) err := ReadBytes(reader, buf)
assert.Equal(t, io.EOF, err) assert.Equal(t, io.EOF, err)

View File

@@ -1,6 +1,7 @@
package iox package iox
import ( import (
"io/ioutil"
"os" "os"
"testing" "testing"
@@ -12,7 +13,7 @@ func TestCountLines(t *testing.T) {
2 2
3 3
4` 4`
file, err := os.CreateTemp(os.TempDir(), "test-") file, err := ioutil.TempFile(os.TempDir(), "test-")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

44
core/jsontype/time.go Normal file
View File

@@ -0,0 +1,44 @@
package jsontype
import (
"encoding/json"
"time"
"github.com/globalsign/mgo/bson"
)
// MilliTime represents time.Time that works better with mongodb.
type MilliTime struct {
time.Time
}
// MarshalJSON marshals mt to json bytes.
func (mt MilliTime) MarshalJSON() ([]byte, error) {
return json.Marshal(mt.Milli())
}
// UnmarshalJSON unmarshals data into mt.
func (mt *MilliTime) UnmarshalJSON(data []byte) error {
var milli int64
if err := json.Unmarshal(data, &milli); err != nil {
return err
}
mt.Time = time.Unix(0, milli*int64(time.Millisecond))
return nil
}
// GetBSON returns BSON base on mt.
func (mt MilliTime) GetBSON() (interface{}, error) {
return mt.Time, nil
}
// SetBSON sets raw into mt.
func (mt *MilliTime) SetBSON(raw bson.Raw) error {
return raw.Unmarshal(&mt.Time)
}
// Milli returns milliseconds for mt.
func (mt MilliTime) Milli() int64 {
return mt.UnixNano() / int64(time.Millisecond)
}

126
core/jsontype/time_test.go Normal file
View File

@@ -0,0 +1,126 @@
package jsontype
import (
"strconv"
"testing"
"time"
"github.com/globalsign/mgo/bson"
"github.com/stretchr/testify/assert"
)
func TestMilliTime_GetBSON(t *testing.T) {
tests := []struct {
name string
tm time.Time
}{
{
name: "now",
tm: time.Now(),
},
{
name: "future",
tm: time.Now().Add(time.Hour),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got, err := MilliTime{test.tm}.GetBSON()
assert.Nil(t, err)
assert.Equal(t, test.tm, got)
})
}
}
func TestMilliTime_MarshalJSON(t *testing.T) {
tests := []struct {
name string
tm time.Time
}{
{
name: "now",
tm: time.Now(),
},
{
name: "future",
tm: time.Now().Add(time.Hour),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
b, err := MilliTime{test.tm}.MarshalJSON()
assert.Nil(t, err)
assert.Equal(t, strconv.FormatInt(test.tm.UnixNano()/1e6, 10), string(b))
})
}
}
func TestMilliTime_Milli(t *testing.T) {
tests := []struct {
name string
tm time.Time
}{
{
name: "now",
tm: time.Now(),
},
{
name: "future",
tm: time.Now().Add(time.Hour),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
n := MilliTime{test.tm}.Milli()
assert.Equal(t, test.tm.UnixNano()/1e6, n)
})
}
}
func TestMilliTime_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
tm time.Time
}{
{
name: "now",
tm: time.Now(),
},
{
name: "future",
tm: time.Now().Add(time.Hour),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var mt MilliTime
s := strconv.FormatInt(test.tm.UnixNano()/1e6, 10)
err := mt.UnmarshalJSON([]byte(s))
assert.Nil(t, err)
s1, err := mt.MarshalJSON()
assert.Nil(t, err)
assert.Equal(t, s, string(s1))
})
}
}
func TestUnmarshalWithError(t *testing.T) {
var mt MilliTime
assert.NotNil(t, mt.UnmarshalJSON([]byte("hello")))
}
func TestSetBSON(t *testing.T) {
data, err := bson.Marshal(time.Now())
assert.Nil(t, err)
var raw bson.Raw
assert.Nil(t, bson.Unmarshal(data, &raw))
var mt MilliTime
assert.Nil(t, mt.SetBSON(raw))
assert.NotNil(t, mt.SetBSON(bson.Raw{}))
}

View File

@@ -13,16 +13,6 @@ func Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v) return json.Marshal(v)
} }
// MarshalToString marshals v into a string.
func MarshalToString(v interface{}) (string, error) {
data, err := Marshal(v)
if err != nil {
return "", err
}
return string(data), nil
}
// Unmarshal unmarshals data bytes into v. // Unmarshal unmarshals data bytes into v.
func Unmarshal(data []byte, v interface{}) error { func Unmarshal(data []byte, v interface{}) error {
decoder := json.NewDecoder(bytes.NewReader(data)) decoder := json.NewDecoder(bytes.NewReader(data))
@@ -61,5 +51,5 @@ func unmarshalUseNumber(decoder *json.Decoder, v interface{}) error {
} }
func formatError(v string, err error) error { func formatError(v string, err error) error {
return fmt.Errorf("string: `%s`, error: `%w`", v, err) return fmt.Errorf("string: `%s`, error: `%s`", v, err.Error())
} }

View File

@@ -1,103 +0,0 @@
package jsonx
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMarshal(t *testing.T) {
var v = struct {
Name string `json:"name"`
Age int `json:"age"`
}{
Name: "John",
Age: 30,
}
bs, err := Marshal(v)
assert.Nil(t, err)
assert.Equal(t, `{"name":"John","age":30}`, string(bs))
}
func TestMarshalToString(t *testing.T) {
var v = struct {
Name string `json:"name"`
Age int `json:"age"`
}{
Name: "John",
Age: 30,
}
toString, err := MarshalToString(v)
assert.Nil(t, err)
assert.Equal(t, `{"name":"John","age":30}`, toString)
_, err = MarshalToString(make(chan int))
assert.NotNil(t, err)
}
func TestUnmarshal(t *testing.T) {
const s = `{"name":"John","age":30}`
var v struct {
Name string `json:"name"`
Age int `json:"age"`
}
err := Unmarshal([]byte(s), &v)
assert.Nil(t, err)
assert.Equal(t, "John", v.Name)
assert.Equal(t, 30, v.Age)
}
func TestUnmarshalError(t *testing.T) {
const s = `{"name":"John","age":30`
var v struct {
Name string `json:"name"`
Age int `json:"age"`
}
err := Unmarshal([]byte(s), &v)
assert.NotNil(t, err)
}
func TestUnmarshalFromString(t *testing.T) {
const s = `{"name":"John","age":30}`
var v struct {
Name string `json:"name"`
Age int `json:"age"`
}
err := UnmarshalFromString(s, &v)
assert.Nil(t, err)
assert.Equal(t, "John", v.Name)
assert.Equal(t, 30, v.Age)
}
func TestUnmarshalFromStringError(t *testing.T) {
const s = `{"name":"John","age":30`
var v struct {
Name string `json:"name"`
Age int `json:"age"`
}
err := UnmarshalFromString(s, &v)
assert.NotNil(t, err)
}
func TestUnmarshalFromRead(t *testing.T) {
const s = `{"name":"John","age":30}`
var v struct {
Name string `json:"name"`
Age int `json:"age"`
}
err := UnmarshalFromReader(strings.NewReader(s), &v)
assert.Nil(t, err)
assert.Equal(t, "John", v.Name)
assert.Equal(t, 30, v.Age)
}
func TestUnmarshalFromReaderError(t *testing.T) {
const s = `{"name":"John","age":30`
var v struct {
Name string `json:"name"`
Age int `json:"age"`
}
err := UnmarshalFromReader(strings.NewReader(s), &v)
assert.NotNil(t, err)
}

View File

@@ -1,11 +1,5 @@
package lang package lang
import (
"fmt"
"reflect"
"strconv"
)
// Placeholder is a placeholder object that can be used globally. // Placeholder is a placeholder object that can be used globally.
var Placeholder PlaceholderType var Placeholder PlaceholderType
@@ -15,64 +9,3 @@ type (
// PlaceholderType represents a placeholder type. // PlaceholderType represents a placeholder type.
PlaceholderType = struct{} PlaceholderType = struct{}
) )
// Repr returns the string representation of v.
func Repr(v interface{}) string {
if v == nil {
return ""
}
// if func (v *Type) String() string, we can't use Elem()
switch vt := v.(type) {
case fmt.Stringer:
return vt.String()
}
val := reflect.ValueOf(v)
for val.Kind() == reflect.Ptr && !val.IsNil() {
val = val.Elem()
}
return reprOfValue(val)
}
func reprOfValue(val reflect.Value) string {
switch vt := val.Interface().(type) {
case bool:
return strconv.FormatBool(vt)
case error:
return vt.Error()
case float32:
return strconv.FormatFloat(float64(vt), 'f', -1, 32)
case float64:
return strconv.FormatFloat(vt, 'f', -1, 64)
case fmt.Stringer:
return vt.String()
case int:
return strconv.Itoa(vt)
case int8:
return strconv.Itoa(int(vt))
case int16:
return strconv.Itoa(int(vt))
case int32:
return strconv.Itoa(int(vt))
case int64:
return strconv.FormatInt(vt, 10)
case string:
return vt
case uint:
return strconv.FormatUint(uint64(vt), 10)
case uint8:
return strconv.FormatUint(uint64(vt), 10)
case uint16:
return strconv.FormatUint(uint64(vt), 10)
case uint32:
return strconv.FormatUint(uint64(vt), 10)
case uint64:
return strconv.FormatUint(vt, 10)
case []byte:
return string(vt)
default:
return fmt.Sprint(val.Interface())
}
}

View File

@@ -1,156 +0,0 @@
package lang
import (
"encoding/json"
"errors"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRepr(t *testing.T) {
var (
f32 float32 = 1.1
f64 = 2.2
i8 int8 = 1
i16 int16 = 2
i32 int32 = 3
i64 int64 = 4
u8 uint8 = 5
u16 uint16 = 6
u32 uint32 = 7
u64 uint64 = 8
)
tests := []struct {
v interface{}
expect string
}{
{
nil,
"",
},
{
mockStringable{},
"mocked",
},
{
new(mockStringable),
"mocked",
},
{
newMockPtr(),
"mockptr",
},
{
&mockOpacity{
val: 1,
},
"{1}",
},
{
true,
"true",
},
{
false,
"false",
},
{
f32,
"1.1",
},
{
f64,
"2.2",
},
{
i8,
"1",
},
{
i16,
"2",
},
{
i32,
"3",
},
{
i64,
"4",
},
{
u8,
"5",
},
{
u16,
"6",
},
{
u32,
"7",
},
{
u64,
"8",
},
{
[]byte(`abcd`),
"abcd",
},
{
mockOpacity{val: 1},
"{1}",
},
}
for _, test := range tests {
t.Run(test.expect, func(t *testing.T) {
assert.Equal(t, test.expect, Repr(test.v))
})
}
}
func TestReprOfValue(t *testing.T) {
t.Run("error", func(t *testing.T) {
assert.Equal(t, "error", reprOfValue(reflect.ValueOf(errors.New("error"))))
})
t.Run("stringer", func(t *testing.T) {
assert.Equal(t, "1.23", reprOfValue(reflect.ValueOf(json.Number("1.23"))))
})
t.Run("int", func(t *testing.T) {
assert.Equal(t, "1", reprOfValue(reflect.ValueOf(1)))
})
t.Run("int", func(t *testing.T) {
assert.Equal(t, "1", reprOfValue(reflect.ValueOf("1")))
})
t.Run("int", func(t *testing.T) {
assert.Equal(t, "1", reprOfValue(reflect.ValueOf(uint(1))))
})
}
type mockStringable struct{}
func (m mockStringable) String() string {
return "mocked"
}
type mockPtr struct{}
func newMockPtr() *mockPtr {
return new(mockPtr)
}
func (m *mockPtr) String() string {
return "mockptr"
}
type mockOpacity struct {
val int
}

View File

@@ -1,28 +1,30 @@
package limit package limit
import ( import (
"context"
"errors" "errors"
"strconv" "strconv"
"time" "time"
"github.com/zeromicro/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
) )
// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key const (
const periodScript = `local limit = tonumber(ARGV[1]) // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
periodScript = `local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2]) local window = tonumber(ARGV[2])
local current = redis.call("INCRBY", KEYS[1], 1) local current = redis.call("INCRBY", KEYS[1], 1)
if current == 1 then if current == 1 then
redis.call("expire", KEYS[1], window) redis.call("expire", KEYS[1], window)
end return 1
if current < limit then elseif current < limit then
return 1 return 1
elseif current == limit then elseif current == limit then
return 2 return 2
else else
return 0 return 0
end` end`
zoneDiff = 3600 * 8 // GMT+8 for our services
)
const ( const (
// Unknown means not initialized state. // Unknown means not initialized state.
@@ -75,12 +77,7 @@ func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string
// Take requests a permit, it returns the permit state. // Take requests a permit, it returns the permit state.
func (h *PeriodLimit) Take(key string) (int, error) { func (h *PeriodLimit) Take(key string) (int, error) {
return h.TakeCtx(context.Background(), key) resp, err := h.limitStore.Eval(periodScript, []string{h.keyPrefix + key}, []string{
}
// TakeCtx requests a permit with context, it returns the permit state.
func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) {
resp, err := h.limitStore.EvalCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{
strconv.Itoa(h.quota), strconv.Itoa(h.quota),
strconv.Itoa(h.calcExpireSeconds()), strconv.Itoa(h.calcExpireSeconds()),
}) })
@@ -107,9 +104,7 @@ func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) {
func (h *PeriodLimit) calcExpireSeconds() int { func (h *PeriodLimit) calcExpireSeconds() int {
if h.align { if h.align {
now := time.Now() unix := time.Now().Unix() + zoneDiff
_, offset := now.Zone()
unix := now.Unix() + int64(offset)
return h.period - int(unix%int64(h.period)) return h.period - int(unix%int64(h.period))
} }
@@ -117,8 +112,6 @@ func (h *PeriodLimit) calcExpireSeconds() int {
} }
// Align returns a func to customize a PeriodLimit with alignment. // Align returns a func to customize a PeriodLimit with alignment.
// For example, if we want to limit end users with 5 sms verification messages every day,
// we need to align with the local timezone and the start of the day.
func Align() PeriodOption { func Align() PeriodOption {
return func(l *PeriodLimit) { return func(l *PeriodLimit) {
l.align = true l.align = true

View File

@@ -5,8 +5,8 @@ import (
"github.com/alicebob/miniredis/v2" "github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/redis/redistest" "github.com/tal-tech/go-zero/core/stores/redis/redistest"
) )
func TestPeriodLimit_Take(t *testing.T) { func TestPeriodLimit_Take(t *testing.T) {
@@ -23,9 +23,10 @@ func TestPeriodLimit_RedisUnavailable(t *testing.T) {
const ( const (
seconds = 1 seconds = 1
total = 100
quota = 5 quota = 5
) )
l := NewPeriodLimit(seconds, quota, redis.New(s.Addr()), "periodlimit") l := NewPeriodLimit(seconds, quota, redis.NewRedis(s.Addr(), redis.NodeType), "periodlimit")
s.Close() s.Close()
val, err := l.Take("first") val, err := l.Take("first")
assert.NotNil(t, err) assert.NotNil(t, err)
@@ -65,13 +66,3 @@ func testPeriodLimit(t *testing.T, opts ...PeriodOption) {
assert.Equal(t, 1, hitQuota) assert.Equal(t, 1, hitQuota)
assert.Equal(t, total-quota, overQuota) assert.Equal(t, total-quota, overQuota)
} }
func TestQuotaFull(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
l := NewPeriodLimit(1, 1, redis.New(s.Addr()), "periodlimit")
val, err := l.Take("first")
assert.Nil(t, err)
assert.Equal(t, HitQuota, val)
}

View File

@@ -1,16 +1,14 @@
package limit package limit
import ( import (
"context"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
xrate "golang.org/x/time/rate" xrate "golang.org/x/time/rate"
) )
@@ -60,8 +58,8 @@ type TokenLimiter struct {
timestampKey string timestampKey string
rescueLock sync.Mutex rescueLock sync.Mutex
redisAlive uint32 redisAlive uint32
monitorStarted bool
rescueLimiter *xrate.Limiter rescueLimiter *xrate.Limiter
monitorStarted bool
} }
// NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits // NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits
@@ -86,31 +84,19 @@ func (lim *TokenLimiter) Allow() bool {
return lim.AllowN(time.Now(), 1) return lim.AllowN(time.Now(), 1)
} }
// AllowCtx is shorthand for AllowNCtx(ctx,time.Now(), 1) with incoming context.
func (lim *TokenLimiter) AllowCtx(ctx context.Context) bool {
return lim.AllowNCtx(ctx, time.Now(), 1)
}
// AllowN reports whether n events may happen at time now. // AllowN reports whether n events may happen at time now.
// Use this method if you intend to drop / skip events that exceed the rate. // Use this method if you intend to drop / skip events that exceed the rate rate.
// Otherwise, use Reserve or Wait. // Otherwise use Reserve or Wait.
func (lim *TokenLimiter) AllowN(now time.Time, n int) bool { func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
return lim.reserveN(context.Background(), now, n) return lim.reserveN(now, n)
} }
// AllowNCtx reports whether n events may happen at time now with incoming context. func (lim *TokenLimiter) reserveN(now time.Time, n int) bool {
// Use this method if you intend to drop / skip events that exceed the rate.
// Otherwise, use Reserve or Wait.
func (lim *TokenLimiter) AllowNCtx(ctx context.Context, now time.Time, n int) bool {
return lim.reserveN(ctx, now, n)
}
func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) bool {
if atomic.LoadUint32(&lim.redisAlive) == 0 { if atomic.LoadUint32(&lim.redisAlive) == 0 {
return lim.rescueLimiter.AllowN(now, n) return lim.rescueLimiter.AllowN(now, n)
} }
resp, err := lim.store.EvalCtx(ctx, resp, err := lim.store.Eval(
script, script,
[]string{ []string{
lim.tokenKey, lim.tokenKey,
@@ -126,12 +112,7 @@ func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) boo
// Lua boolean false -> r Nil bulk reply // Lua boolean false -> r Nil bulk reply
if err == redis.Nil { if err == redis.Nil {
return false return false
} } else if err != nil {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
logx.Errorf("fail to use rate limiter: %s", err)
return false
}
if err != nil {
logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err) logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
lim.startMonitor() lim.startMonitor()
return lim.rescueLimiter.AllowN(now, n) return lim.rescueLimiter.AllowN(now, n)

View File

@@ -1,45 +1,20 @@
package limit package limit
import ( import (
"context"
"testing" "testing"
"time" "time"
"github.com/alicebob/miniredis/v2" "github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis" "github.com/tal-tech/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/redis/redistest" "github.com/tal-tech/go-zero/core/stores/redis/redistest"
) )
func init() { func init() {
logx.Disable() logx.Disable()
} }
func TestTokenLimit_WithCtx(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
const (
total = 100
rate = 5
burst = 10
)
l := NewTokenLimiter(rate, burst, redis.New(s.Addr()), "tokenlimit")
defer s.Close()
ctx, cancel := context.WithCancel(context.Background())
ok := l.AllowCtx(ctx)
assert.True(t, ok)
cancel()
for i := 0; i < total; i++ {
ok := l.AllowCtx(ctx)
assert.False(t, ok)
assert.False(t, l.monitorStarted)
}
}
func TestTokenLimit_Rescue(t *testing.T) { func TestTokenLimit_Rescue(t *testing.T) {
s, err := miniredis.Run() s, err := miniredis.Run()
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -7,17 +7,17 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/zeromicro/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stat" "github.com/tal-tech/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/timex" "github.com/tal-tech/go-zero/core/timex"
) )
const ( const (
defaultBuckets = 50 defaultBuckets = 50
defaultWindow = time.Second * 5 defaultWindow = time.Second * 5
// using 1000m notation, 900m is like 90%, keep it as var for unit test // using 1000m notation, 900m is like 80%, keep it as var for unit test
defaultCpuThreshold = 900 defaultCpuThreshold = 900
defaultMinRt = float64(time.Second / time.Millisecond) defaultMinRt = float64(time.Second / time.Millisecond)
// moving average hyperparameter beta for calculating requests on the fly // moving average hyperparameter beta for calculating requests on the fly
@@ -70,7 +70,7 @@ type (
flying int64 flying int64
avgFlying float64 avgFlying float64
avgFlyingLock syncx.SpinLock avgFlyingLock syncx.SpinLock
overloadTime *syncx.AtomicDuration dropTime *syncx.AtomicDuration
droppedRecently *syncx.AtomicBool droppedRecently *syncx.AtomicBool
passCounter *collection.RollingWindow passCounter *collection.RollingWindow
rtCounter *collection.RollingWindow rtCounter *collection.RollingWindow
@@ -106,7 +106,7 @@ func NewAdaptiveShedder(opts ...ShedderOption) Shedder {
return &adaptiveShedder{ return &adaptiveShedder{
cpuThreshold: options.cpuThreshold, cpuThreshold: options.cpuThreshold,
windows: int64(time.Second / bucketDuration), windows: int64(time.Second / bucketDuration),
overloadTime: syncx.NewAtomicDuration(), dropTime: syncx.NewAtomicDuration(),
droppedRecently: syncx.NewAtomicBool(), droppedRecently: syncx.NewAtomicBool(),
passCounter: collection.NewRollingWindow(options.buckets, bucketDuration, passCounter: collection.NewRollingWindow(options.buckets, bucketDuration,
collection.IgnoreCurrentBucket()), collection.IgnoreCurrentBucket()),
@@ -118,6 +118,7 @@ func NewAdaptiveShedder(opts ...ShedderOption) Shedder {
// Allow implements Shedder.Allow. // Allow implements Shedder.Allow.
func (as *adaptiveShedder) Allow() (Promise, error) { func (as *adaptiveShedder) Allow() (Promise, error) {
if as.shouldDrop() { if as.shouldDrop() {
as.dropTime.Set(timex.Now())
as.droppedRecently.Set(true) as.droppedRecently.Set(true)
return nil, ErrServiceOverloaded return nil, ErrServiceOverloaded
@@ -214,26 +215,21 @@ func (as *adaptiveShedder) stillHot() bool {
return false return false
} }
overloadTime := as.overloadTime.Load() dropTime := as.dropTime.Load()
if overloadTime == 0 { if dropTime == 0 {
return false return false
} }
if timex.Since(overloadTime) < coolOffDuration { hot := timex.Since(dropTime) < coolOffDuration
return true if !hot {
as.droppedRecently.Set(false)
} }
as.droppedRecently.Set(false) return hot
return false
} }
func (as *adaptiveShedder) systemOverloaded() bool { func (as *adaptiveShedder) systemOverloaded() bool {
if !systemOverloadChecker(as.cpuThreshold) { return systemOverloadChecker(as.cpuThreshold)
return false
}
as.overloadTime.Set(timex.Now())
return true
} }
// WithBuckets customizes the Shedder with given number of buckets. // WithBuckets customizes the Shedder with given number of buckets.

View File

@@ -8,12 +8,11 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mathx" "github.com/tal-tech/go-zero/core/mathx"
"github.com/zeromicro/go-zero/core/stat" "github.com/tal-tech/go-zero/core/stat"
"github.com/zeromicro/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
"github.com/zeromicro/go-zero/core/timex"
) )
const ( const (
@@ -137,7 +136,7 @@ func TestAdaptiveShedderShouldDrop(t *testing.T) {
passCounter: passCounter, passCounter: passCounter,
rtCounter: rtCounter, rtCounter: rtCounter,
windows: buckets, windows: buckets,
overloadTime: syncx.NewAtomicDuration(), dropTime: syncx.NewAtomicDuration(),
droppedRecently: syncx.NewAtomicBool(), droppedRecently: syncx.NewAtomicBool(),
} }
// cpu >= 800, inflight < maxPass // cpu >= 800, inflight < maxPass
@@ -191,15 +190,12 @@ func TestAdaptiveShedderStillHot(t *testing.T) {
passCounter: passCounter, passCounter: passCounter,
rtCounter: rtCounter, rtCounter: rtCounter,
windows: buckets, windows: buckets,
overloadTime: syncx.NewAtomicDuration(), dropTime: syncx.NewAtomicDuration(),
droppedRecently: syncx.ForAtomicBool(true), droppedRecently: syncx.ForAtomicBool(true),
} }
assert.False(t, shedder.stillHot()) assert.False(t, shedder.stillHot())
shedder.overloadTime.Set(-coolOffDuration * 2) shedder.dropTime.Set(-coolOffDuration * 2)
assert.False(t, shedder.stillHot()) assert.False(t, shedder.stillHot())
shedder.droppedRecently.Set(true)
shedder.overloadTime.Set(timex.Now())
assert.True(t, shedder.stillHot())
} }
func BenchmarkAdaptiveShedder_Allow(b *testing.B) { func BenchmarkAdaptiveShedder_Allow(b *testing.B) {

View File

@@ -3,7 +3,7 @@ package load
import ( import (
"io" "io"
"github.com/zeromicro/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
) )
// A ShedderGroup is a manager to manage key based shedders. // A ShedderGroup is a manager to manage key based shedders.

View File

@@ -4,8 +4,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/zeromicro/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stat" "github.com/tal-tech/go-zero/core/stat"
) )
type ( type (

Some files were not shown because too many files have changed in this diff Show More