mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-06-13 17:31:56 +08:00
Compare commits
91 Commits
tools/goct
...
v1.9.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f29c8612e8 | ||
|
|
35ba024103 | ||
|
|
52df1c532a | ||
|
|
39729f3756 | ||
|
|
5c9ea81db2 | ||
|
|
b284664de4 | ||
|
|
1b76885040 | ||
|
|
eef217522b | ||
|
|
6bd0d169d5 | ||
|
|
3d291328d8 | ||
|
|
858f8ca82e | ||
|
|
4ff3975c5a | ||
|
|
7b23f73268 | ||
|
|
918a7be698 | ||
|
|
0a724447cd | ||
|
|
9e425893a7 | ||
|
|
4de13b6cc8 | ||
|
|
c6f75532fa | ||
|
|
fdf4ccf057 | ||
|
|
b333ed245b | ||
|
|
8f1576df36 | ||
|
|
72dd970969 | ||
|
|
29b65e12c1 | ||
|
|
577a611dc3 | ||
|
|
75941aedd4 | ||
|
|
c7065171d7 | ||
|
|
052de3b552 | ||
|
|
866613af8c | ||
|
|
3d4f6a5e16 | ||
|
|
d1d47d02d5 | ||
|
|
d6c876860b | ||
|
|
98423ca948 | ||
|
|
4e52d77ad8 | ||
|
|
1fc2cfb859 | ||
|
|
942cdae41d | ||
|
|
e9c3607bc6 | ||
|
|
d1603e9166 | ||
|
|
e30317e9c4 | ||
|
|
568f9ce007 | ||
|
|
dcb309065a | ||
|
|
bf8e17a686 | ||
|
|
b2ebbfce62 | ||
|
|
2b10a6a223 | ||
|
|
80c320b46e | ||
|
|
bea9d150a1 | ||
|
|
3f756a2cbf | ||
|
|
bbe5bbb0c0 | ||
|
|
5ad2278a69 | ||
|
|
77763fe748 | ||
|
|
538c4fb5c7 | ||
|
|
315fb2fe0a | ||
|
|
e382887eb8 | ||
|
|
cf21cb2b0b | ||
|
|
61e8894c31 | ||
|
|
7a6c3c8129 | ||
|
|
875fec3e1a | ||
|
|
60128c2100 | ||
|
|
ce6d0e3ea7 | ||
|
|
fa85c84af3 | ||
|
|
440884105e | ||
|
|
271f10598f | ||
|
|
cf55a88ce3 | ||
|
|
c1c786b14a | ||
|
|
988fb9d9bf | ||
|
|
d212c81bca | ||
|
|
bc43df2641 | ||
|
|
351b8cb37b | ||
|
|
0d681a2e29 | ||
|
|
5ea027c5de | ||
|
|
5de6112dcd | ||
|
|
4fb51723b7 | ||
|
|
06502d1115 | ||
|
|
3854d6dd00 | ||
|
|
895854913a | ||
|
|
ef753b8857 | ||
|
|
9c16fede73 | ||
|
|
ce11adb5e4 | ||
|
|
894e8b1218 | ||
|
|
2ec7e432dd | ||
|
|
870e8352c1 | ||
|
|
de42f27e03 | ||
|
|
955b8016aa | ||
|
|
d728a3b2d9 | ||
|
|
0c205a71fc | ||
|
|
a8c0199d96 | ||
|
|
032a266ec4 | ||
|
|
40b75fbb9b | ||
|
|
afad55045b | ||
|
|
5f54f06ee5 | ||
|
|
20f56ae1d0 | ||
|
|
73d6fcfccd |
241
.github/copilot-instructions.md
vendored
Normal file
241
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
# GitHub Copilot Instructions for go-zero
|
||||||
|
|
||||||
|
This document provides guidelines for GitHub Copilot when assisting with development in the go-zero project.
|
||||||
|
|
||||||
|
## Project Overview
|
||||||
|
|
||||||
|
go-zero is a web and RPC framework with lots of built-in engineering practices designed to ensure the stability of busy services with resilience design. It has been serving sites with tens of millions of users for years.
|
||||||
|
|
||||||
|
### Key Architecture Components
|
||||||
|
|
||||||
|
- **REST API framework** (`rest/`) - HTTP service framework with middleware chain support
|
||||||
|
- **RPC framework** (`zrpc/`) - gRPC-based RPC framework with etcd service discovery and p2c_ewma load balancing
|
||||||
|
- **Gateway** (`gateway/`) - API gateway supporting both HTTP and gRPC upstreams with proto-based routing
|
||||||
|
- **MCP Server** (`mcp/`) - Model Context Protocol server for AI agent integration via SSE
|
||||||
|
- **Core utilities** (`core/`) - Production-grade components:
|
||||||
|
- Resilience: circuit breakers (`breaker/`), rate limiters (`limit/`), adaptive load shedding (`load/`)
|
||||||
|
- Storage: SQL with cache (`stores/sqlc/`), Redis (`stores/redis/`), MongoDB (`stores/mongo/`)
|
||||||
|
- Concurrency: MapReduce (`mr/`), worker pools (`executors/`), sync primitives (`syncx/`)
|
||||||
|
- Observability: metrics (`metric/`), tracing (`trace/`), structured logging (`logx/`)
|
||||||
|
- **Code generation tool** (`tools/goctl/`) - CLI tool for generating Go code from `.api` and `.proto` files
|
||||||
|
|
||||||
|
## Coding Standards and Conventions
|
||||||
|
|
||||||
|
### Code Style
|
||||||
|
|
||||||
|
1. **Follow Go conventions**: Use `gofmt` for formatting, follow effective Go practices
|
||||||
|
2. **Package naming**: Use lowercase, single-word package names when possible
|
||||||
|
3. **Error handling**: Always handle errors explicitly, use `errorx.BatchError` for multiple errors
|
||||||
|
4. **Context propagation**: Always pass `context.Context` as the first parameter for functions that may block
|
||||||
|
5. **Configuration structures**: Use struct tags with JSON annotations, defaults, and validation
|
||||||
|
|
||||||
|
**Pattern**: All service configs embed `service.ServiceConf` for common fields (Name, Log, Mode, Telemetry)
|
||||||
|
```go
|
||||||
|
type Config struct {
|
||||||
|
service.ServiceConf // Always embed for services
|
||||||
|
Host string `json:",default=0.0.0.0"`
|
||||||
|
Port int // Required field (no default)
|
||||||
|
Timeout int64 `json:",default=3000"` // Timeouts in milliseconds
|
||||||
|
Optional string `json:",optional"` // Optional field
|
||||||
|
Mode string `json:",default=pro,options=dev|test|rt|pre|pro"` // Validated options
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Service modes**: `dev`/`test`/`rt` disable load shedding and stats; `pre`/`pro` enable all resilience features
|
||||||
|
|
||||||
|
### Interface Design
|
||||||
|
|
||||||
|
1. **Small interfaces**: Follow Go's preference for small, focused interfaces
|
||||||
|
2. **Context methods**: Provide both context and non-context versions of methods
|
||||||
|
3. **Options pattern**: Use functional options for complex configuration
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```go
|
||||||
|
func (c *Client) Get(key string, val any) error {
|
||||||
|
return c.GetCtx(context.Background(), key, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) GetCtx(ctx context.Context, key string, val any) error {
|
||||||
|
// implementation
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testing Patterns
|
||||||
|
|
||||||
|
1. **Test file naming**: Use `*_test.go` suffix
|
||||||
|
2. **Test function naming**: Use `TestFunctionName` pattern
|
||||||
|
3. **Use testify/assert**: Prefer `assert` package for assertions
|
||||||
|
4. **Table-driven tests**: Use table-driven tests for multiple scenarios
|
||||||
|
5. **Mock interfaces**: Use `go.uber.org/mock` for mocking
|
||||||
|
6. **Test helpers**: Use `redistest`, `mongtest` helpers for database testing
|
||||||
|
|
||||||
|
Example test pattern:
|
||||||
|
```go
|
||||||
|
func TestSomething(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"valid case", "input", "output", false},
|
||||||
|
{"error case", "bad", "", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := SomeFunction(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Framework-Specific Guidelines
|
||||||
|
|
||||||
|
### REST API Development
|
||||||
|
|
||||||
|
1. **API Definition**: Use `.api` files to define REST APIs with goctl codegen
|
||||||
|
2. **Handler pattern**: Separate business logic into logic packages (handlers call logic layer)
|
||||||
|
3. **Middleware chain**: Middlewares wrap via `chain.Chain` interface - use `Append()` or `Prepend()` to control order
|
||||||
|
- Built-in middlewares (all in `rest/handler/`): tracing, logging, metrics, recovery, breaker, shedding, timeout, maxconns, maxbytes, gunzip
|
||||||
|
- Custom middleware: `func(http.Handler) http.Handler` - call `next.ServeHTTP(w, r)` to continue chain
|
||||||
|
4. **Response handling**: Use `httpx.WriteJson(w, code, v)` for JSON responses
|
||||||
|
5. **Error handling**: Use `httpx.Error(w, err)` or `httpx.ErrorCtx(ctx, w, err)` for HTTP error responses
|
||||||
|
6. **Route registration**: Routes defined with `Method`, `Path`, and `Handler` - wildcards use `:param` syntax
|
||||||
|
|
||||||
|
### RPC Development
|
||||||
|
|
||||||
|
1. **Protocol Buffers**: Use protobuf for service definitions, generate code with goctl
|
||||||
|
2. **Service discovery**: Use etcd for dynamic service registration/discovery, or direct endpoints for static routing
|
||||||
|
3. **Load balancing**: Default is `p2c_ewma` (power of 2 choices with EWMA), configurable via `BalancerName`
|
||||||
|
4. **Client configuration**: Support `Etcd`, `Endpoints`, or `Target` - use `BuildTarget()` to construct connection string
|
||||||
|
5. **Interceptors**: Implement gRPC interceptors for cross-cutting concerns (auth, logging, metrics)
|
||||||
|
6. **Health checks**: gRPC health checks enabled by default (`Health: true`)
|
||||||
|
|
||||||
|
### Database Operations
|
||||||
|
|
||||||
|
1. **SQL operations**: Use `sqlx.SqlConn` interface - methods always end with `Ctx` for context support
|
||||||
|
2. **Caching pattern**: `stores/sqlc` provides `CachedConn` for automatic cache-aside pattern
|
||||||
|
- `QueryRowCtx`: Query with cache key, auto-populate on cache miss
|
||||||
|
- `ExecCtx`: Execute and delete cache keys
|
||||||
|
3. **Transactions**: Use `sqlx.SqlConn.TransactCtx()` to get transaction session
|
||||||
|
4. **Connection pooling**: Managed automatically (64 max idle/open, 1min lifetime)
|
||||||
|
5. **Test helpers**: Use `redistest.CreateRedis(t)` for Redis, SQL mocks for DB testing
|
||||||
|
|
||||||
|
Example cache pattern:
|
||||||
|
```go
|
||||||
|
err := c.QueryRowCtx(ctx, &dest, key, func(ctx context.Context, conn sqlx.SqlConn) error {
|
||||||
|
return conn.QueryRowCtx(ctx, &dest, query, args...)
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration Management
|
||||||
|
|
||||||
|
1. **YAML configuration**: Use YAML for configuration files
|
||||||
|
2. **Environment variables**: Support environment variable overrides
|
||||||
|
3. **Validation**: Include proper validation for configuration parameters
|
||||||
|
4. **Sensible defaults**: Provide reasonable default values
|
||||||
|
|
||||||
|
## Error Handling Best Practices
|
||||||
|
|
||||||
|
1. **Wrap errors**: Use `fmt.Errorf` with `%w` verb to wrap errors
|
||||||
|
2. **Custom errors**: Define custom error types when needed
|
||||||
|
3. **Error logging**: Log errors appropriately with context
|
||||||
|
4. **Graceful degradation**: Implement fallback mechanisms
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
1. **Resource pools**: Use connection pools and worker pools
|
||||||
|
2. **Circuit breakers**: Implement circuit breaker patterns for external calls
|
||||||
|
3. **Rate limiting**: Apply rate limiting to protect services
|
||||||
|
4. **Load shedding**: Implement adaptive load shedding
|
||||||
|
5. **Metrics**: Add appropriate metrics and monitoring
|
||||||
|
|
||||||
|
## Security Guidelines
|
||||||
|
|
||||||
|
1. **Input validation**: Validate all input parameters
|
||||||
|
2. **SQL injection prevention**: Use parameterized queries
|
||||||
|
3. **Authentication**: Implement proper JWT token handling
|
||||||
|
4. **HTTPS**: Support TLS/HTTPS configurations
|
||||||
|
5. **CORS**: Configure CORS appropriately for web APIs
|
||||||
|
|
||||||
|
## Documentation Standards
|
||||||
|
|
||||||
|
1. **Package documentation**: Include package-level documentation
|
||||||
|
2. **Function documentation**: Document exported functions with examples
|
||||||
|
3. **API documentation**: Maintain API documentation in sync
|
||||||
|
4. **README updates**: Update README for significant changes
|
||||||
|
|
||||||
|
## Common Patterns to Follow
|
||||||
|
|
||||||
|
### Service Configuration
|
||||||
|
```go
|
||||||
|
type ServiceConf struct {
|
||||||
|
Name string
|
||||||
|
Log logx.LogConf
|
||||||
|
Mode string `json:",default=pro,options=[dev,test,pre,pro]"`
|
||||||
|
// ... other common fields
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Middleware Implementation
|
||||||
|
```go
|
||||||
|
func SomeMiddleware() rest.Middleware {
|
||||||
|
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Pre-processing
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
// Post-processing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Resource Management
|
||||||
|
Always implement proper resource cleanup using defer and context cancellation.
|
||||||
|
|
||||||
|
## Build and Test Commands
|
||||||
|
|
||||||
|
- Build: `go build ./...`
|
||||||
|
- Test: `go test ./...`
|
||||||
|
- Test with race detection: `go test -race ./...`
|
||||||
|
- Format: `gofmt -w .`
|
||||||
|
- Code generation:
|
||||||
|
- REST API: `goctl api go -api *.api -dir .`
|
||||||
|
- RPC: `goctl rpc protoc *.proto --go_out=. --go-grpc_out=. --zrpc_out=.`
|
||||||
|
- Model from SQL: `goctl model mysql datasource -url="user:pass@tcp(host:port)/db" -table="*" -dir="./model"`
|
||||||
|
|
||||||
|
## Critical Architecture Patterns
|
||||||
|
|
||||||
|
### Resilience Design Philosophy
|
||||||
|
go-zero implements defense-in-depth with multiple protection layers:
|
||||||
|
1. **Circuit Breaker** (`core/breaker`): Google SRE breaker - tracks success/failure, opens on error threshold
|
||||||
|
2. **Adaptive Load Shedding** (`core/load`): CPU-based auto-rejection when system overloaded (disabled in dev/test/rt modes)
|
||||||
|
3. **Rate Limiting** (`core/limit`): Token bucket (Redis-based) and period limiters
|
||||||
|
4. **Timeout Control**: Cascading timeouts via context - set at multiple levels (client, server, handler)
|
||||||
|
|
||||||
|
### Middleware Chain Architecture
|
||||||
|
`rest/chain` provides middleware composition:
|
||||||
|
```go
|
||||||
|
// Middleware signature
|
||||||
|
type Middleware func(http.Handler) http.Handler
|
||||||
|
|
||||||
|
// Chain operations
|
||||||
|
chain := chain.New(m1, m2)
|
||||||
|
chain.Append(m3) // Adds to end: m1 -> m2 -> m3
|
||||||
|
chain.Prepend(m0) // Adds to start: m0 -> m1 -> m2 -> m3
|
||||||
|
handler := chain.Then(finalHandler)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Concurrency Patterns
|
||||||
|
- **MapReduce** (`core/mr`): Parallel processing with worker pools - use for batch operations
|
||||||
|
- **Executors** (`core/executors`): Bulk/period executors for batching operations
|
||||||
|
- **SingleFlight** (`core/syncx`): Deduplicates concurrent identical requests
|
||||||
|
|
||||||
|
Remember to run tests and ensure all checks pass before submitting changes. The project emphasizes high quality, performance, and reliability, so these should be primary considerations in all development work.
|
||||||
8
.github/workflows/codeql-analysis.yml
vendored
8
.github/workflows/codeql-analysis.yml
vendored
@@ -35,11 +35,11 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
# Initializes the CodeQL tools for scanning.
|
# Initializes the CodeQL tools for scanning.
|
||||||
- name: Initialize CodeQL
|
- name: Initialize CodeQL
|
||||||
uses: github/codeql-action/init@v3
|
uses: github/codeql-action/init@v4
|
||||||
with:
|
with:
|
||||||
languages: ${{ matrix.language }}
|
languages: ${{ matrix.language }}
|
||||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||||
@@ -50,7 +50,7 @@ jobs:
|
|||||||
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
|
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
|
||||||
# If this step fails, then you should remove it and run the build manually (see below)
|
# If this step fails, then you should remove it and run the build manually (see below)
|
||||||
- name: Autobuild
|
- name: Autobuild
|
||||||
uses: github/codeql-action/autobuild@v3
|
uses: github/codeql-action/autobuild@v4
|
||||||
|
|
||||||
# ℹ️ Command-line programs to run using the OS shell.
|
# ℹ️ Command-line programs to run using the OS shell.
|
||||||
# 📚 https://git.io/JvXDl
|
# 📚 https://git.io/JvXDl
|
||||||
@@ -64,4 +64,4 @@ jobs:
|
|||||||
# make release
|
# make release
|
||||||
|
|
||||||
- name: Perform CodeQL Analysis
|
- name: Perform CodeQL Analysis
|
||||||
uses: github/codeql-action/analyze@v3
|
uses: github/codeql-action/analyze@v4
|
||||||
|
|||||||
8
.github/workflows/go.yml
vendored
8
.github/workflows/go.yml
vendored
@@ -12,10 +12,10 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Check out code into the Go module directory
|
- name: Check out code into the Go module directory
|
||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Go 1.x
|
- name: Set up Go 1.x
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version-file: go.mod
|
go-version-file: go.mod
|
||||||
check-latest: true
|
check-latest: true
|
||||||
@@ -52,10 +52,10 @@ jobs:
|
|||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout codebase
|
- name: Checkout codebase
|
||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Go 1.x
|
- name: Set up Go 1.x
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
# make sure Go version compatible with go-zero
|
# make sure Go version compatible with go-zero
|
||||||
go-version-file: go.mod
|
go-version-file: go.mod
|
||||||
|
|||||||
2
.github/workflows/issues.yml
vendored
2
.github/workflows/issues.yml
vendored
@@ -7,7 +7,7 @@ jobs:
|
|||||||
close-issues:
|
close-issues:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/stale@v9
|
- uses: actions/stale@v10
|
||||||
with:
|
with:
|
||||||
days-before-issue-stale: 365
|
days-before-issue-stale: 365
|
||||||
days-before-issue-close: 90
|
days-before-issue-close: 90
|
||||||
|
|||||||
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -16,7 +16,7 @@ jobs:
|
|||||||
- goarch: "386"
|
- goarch: "386"
|
||||||
goos: darwin
|
goos: darwin
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
- uses: zeromicro/go-zero-release-action@master
|
- uses: zeromicro/go-zero-release-action@master
|
||||||
with:
|
with:
|
||||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
7
.github/workflows/reviewdog.yml
vendored
7
.github/workflows/reviewdog.yml
vendored
@@ -5,7 +5,12 @@ jobs:
|
|||||||
name: runner / staticcheck
|
name: runner / staticcheck
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
|
- uses: actions/setup-go@v6
|
||||||
|
with:
|
||||||
|
go-version-file: go.mod
|
||||||
|
check-latest: true
|
||||||
|
cache: true
|
||||||
- uses: reviewdog/action-staticcheck@v1
|
- uses: reviewdog/action-staticcheck@v1
|
||||||
with:
|
with:
|
||||||
github_token: ${{ secrets.github_token }}
|
github_token: ${{ secrets.github_token }}
|
||||||
|
|||||||
4
.github/workflows/version-check.yml
vendored
4
.github/workflows/version-check.yml
vendored
@@ -10,10 +10,10 @@ jobs:
|
|||||||
version-check:
|
version-check:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version: '1.21'
|
go-version: '1.21'
|
||||||
|
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -17,6 +17,7 @@
|
|||||||
**/logs
|
**/logs
|
||||||
**/adhoc
|
**/adhoc
|
||||||
**/coverage.txt
|
**/coverage.txt
|
||||||
|
**/WARP.md
|
||||||
|
|
||||||
# for test purpose
|
# for test purpose
|
||||||
go.work
|
go.work
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ type (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// New create a Filter, store is the backed redis, key is the key for the bloom filter,
|
// New creates a Filter, store is the backed redis, key is the key for the bloom filter,
|
||||||
// bits is how many bits will be used, maps is how many hashes for each addition.
|
// bits is how many bits will be used, maps is how many hashes for each addition.
|
||||||
// best practices:
|
// best practices:
|
||||||
// elements - means how many actual elements
|
// elements - means how many actual elements
|
||||||
|
|||||||
@@ -81,6 +81,10 @@ func (c *Cache) Del(key string) {
|
|||||||
delete(c.data, key)
|
delete(c.data, key)
|
||||||
c.lruCache.remove(key)
|
c.lruCache.remove(key)
|
||||||
c.lock.Unlock()
|
c.lock.Unlock()
|
||||||
|
|
||||||
|
// RemoveTimer is called outside the lock to avoid performance impact from this
|
||||||
|
// potentially time-consuming operation. Data integrity is maintained by lruCache,
|
||||||
|
// which will eventually evict any remaining entries when capacity is exceeded.
|
||||||
c.timingWheel.RemoveTimer(key)
|
c.timingWheel.RemoveTimer(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -164,6 +164,7 @@ func (tw *TimingWheel) Stop() {
|
|||||||
|
|
||||||
func (tw *TimingWheel) drainAll(fn func(key, value any)) {
|
func (tw *TimingWheel) drainAll(fn func(key, value any)) {
|
||||||
runner := threading.NewTaskRunner(drainWorkers)
|
runner := threading.NewTaskRunner(drainWorkers)
|
||||||
|
|
||||||
for _, slot := range tw.slots {
|
for _, slot := range tw.slots {
|
||||||
for e := slot.Front(); e != nil; {
|
for e := slot.Front(); e != nil; {
|
||||||
task := e.Value.(*timingEntry)
|
task := e.Value.(*timingEntry)
|
||||||
@@ -177,6 +178,8 @@ func (tw *TimingWheel) drainAll(fn func(key, value any)) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
runner.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {
|
func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {
|
||||||
|
|||||||
@@ -629,6 +629,157 @@ func TestMoveAndRemoveTask(t *testing.T) {
|
|||||||
assert.Equal(t, 0, len(keys))
|
assert.Equal(t, 0, len(keys))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestTimingWheel_DrainClosureBug tests the closure capture bug in drainAll
|
||||||
|
// Issue: https://github.com/zeromicro/go-zero/issues/5314
|
||||||
|
func TestTimingWheel_DrainClosureBug(t *testing.T) {
|
||||||
|
ticker := timex.NewFakeTicker()
|
||||||
|
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {}, ticker)
|
||||||
|
defer tw.Stop()
|
||||||
|
|
||||||
|
// Set multiple timers with different values
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
tw.SetTimer(i, i*10, testStep*5)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give time for timers to be set
|
||||||
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
received := make(map[int]int)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(10)
|
||||||
|
|
||||||
|
tw.Drain(func(key, value any) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
k := key.(int)
|
||||||
|
v := value.(int)
|
||||||
|
received[k] = v
|
||||||
|
wg.Done()
|
||||||
|
})
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Check if all values match their keys
|
||||||
|
for k, v := range received {
|
||||||
|
expected := k * 10
|
||||||
|
assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTimingWheel_RunTasksClosureBug tests the closure capture bug in runTasks
|
||||||
|
// Issue: https://github.com/zeromicro/go-zero/issues/5314
|
||||||
|
func TestTimingWheel_RunTasksClosureBug(t *testing.T) {
|
||||||
|
ticker := timex.NewFakeTicker()
|
||||||
|
var mu sync.Mutex
|
||||||
|
executed := make(map[int]int)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
key := k.(int)
|
||||||
|
val := v.(int)
|
||||||
|
executed[key] = val
|
||||||
|
wg.Done()
|
||||||
|
}, ticker)
|
||||||
|
defer tw.Stop()
|
||||||
|
|
||||||
|
// Set multiple timers that should fire in the same tick
|
||||||
|
count := 10
|
||||||
|
wg.Add(count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
tw.SetTimer(i, i*10, testStep)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advance ticker to trigger tasks
|
||||||
|
ticker.Tick()
|
||||||
|
|
||||||
|
// Wait for execution with timeout
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Success
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for tasks to execute")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all tasks executed with correct values
|
||||||
|
assert.Equal(t, count, len(executed), "should have executed all tasks")
|
||||||
|
for k, v := range executed {
|
||||||
|
expected := k * 10
|
||||||
|
assert.Equal(t, expected, v, "key %d should have value %d, got %d", k, expected, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTimingWheel_RunTasksRaceCondition tests for race conditions in runTasks
|
||||||
|
// This test specifically targets the loop variable capture bug
|
||||||
|
func TestTimingWheel_RunTasksRaceCondition(t *testing.T) {
|
||||||
|
// Run multiple times to increase likelihood of catching the bug
|
||||||
|
for attempt := 0; attempt < 10; attempt++ {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
ticker := timex.NewFakeTicker()
|
||||||
|
var mu sync.Mutex
|
||||||
|
keyValues := make(map[int][]int)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v any) {
|
||||||
|
// Add small delay to increase chance of race
|
||||||
|
time.Sleep(time.Microsecond)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
key := k.(int)
|
||||||
|
val := v.(int)
|
||||||
|
keyValues[key] = append(keyValues[key], val)
|
||||||
|
wg.Done()
|
||||||
|
}, ticker)
|
||||||
|
defer tw.Stop()
|
||||||
|
|
||||||
|
// Set many timers rapidly to increase chance of race
|
||||||
|
count := 50
|
||||||
|
wg.Add(count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
tw.SetTimer(i, i*100, testStep)
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker.Tick()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for tasks")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for duplicates or wrong values
|
||||||
|
wrongCount := 0
|
||||||
|
for key, values := range keyValues {
|
||||||
|
assert.Equal(t, 1, len(values), "key %d should only execute once, got %v", key, values)
|
||||||
|
if len(values) > 0 {
|
||||||
|
expected := key * 100
|
||||||
|
if values[0] != expected {
|
||||||
|
wrongCount++
|
||||||
|
t.Logf("BUG DETECTED: key %d should have value %d, got %d", key, expected, values[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if wrongCount > 0 {
|
||||||
|
t.Errorf("Found %d tasks with wrong values due to closure bug", wrongCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkTimingWheel(b *testing.B) {
|
func BenchmarkTimingWheel(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
|||||||
@@ -368,5 +368,5 @@ func getFullName(parent, child string) string {
|
|||||||
return child
|
return child
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join([]string{parent, child}, ".")
|
return parent + "." + child
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1377,3 +1377,23 @@ func (m mockConfig) Validate() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetFullName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
parent string
|
||||||
|
child string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"", "child", "child"},
|
||||||
|
{"parent", "child", "parent.child"},
|
||||||
|
{"a.b", "c", "a.b.c"},
|
||||||
|
{"root", "nested.field", "root.nested.field"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.parent+"."+tt.child, func(t *testing.T) {
|
||||||
|
got := getFullName(tt.parent, tt.child)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package subscriber
|
package subscriber
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/discov"
|
"github.com/zeromicro/go-zero/core/discov"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
)
|
)
|
||||||
@@ -37,6 +40,7 @@ func NewEtcdSubscriber(conf EtcdConf) (Subscriber, error) {
|
|||||||
func buildSubOptions(conf EtcdConf) []discov.SubOption {
|
func buildSubOptions(conf EtcdConf) []discov.SubOption {
|
||||||
opts := []discov.SubOption{
|
opts := []discov.SubOption{
|
||||||
discov.WithExactMatch(),
|
discov.WithExactMatch(),
|
||||||
|
discov.WithContainer(newContainer()),
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(conf.User) > 0 {
|
if len(conf.User) > 0 {
|
||||||
@@ -65,3 +69,47 @@ func (s *etcdSubscriber) Value() (string, error) {
|
|||||||
|
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type container struct {
|
||||||
|
value atomic.Value
|
||||||
|
listeners []func()
|
||||||
|
lock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newContainer() *container {
|
||||||
|
return &container{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *container) OnAdd(kv discov.KV) {
|
||||||
|
c.value.Store([]string{kv.Val})
|
||||||
|
c.notifyChange()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *container) OnDelete(_ discov.KV) {
|
||||||
|
c.value.Store([]string(nil))
|
||||||
|
c.notifyChange()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *container) AddListener(listener func()) {
|
||||||
|
c.lock.Lock()
|
||||||
|
c.listeners = append(c.listeners, listener)
|
||||||
|
c.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *container) GetValues() []string {
|
||||||
|
if vals, ok := c.value.Load().([]string); ok {
|
||||||
|
return vals
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *container) notifyChange() {
|
||||||
|
c.lock.Lock()
|
||||||
|
listeners := append(([]func())(nil), c.listeners...)
|
||||||
|
c.lock.Unlock()
|
||||||
|
|
||||||
|
for _, listener := range listeners {
|
||||||
|
listener()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
186
core/configcenter/subscriber/etcd_test.go
Normal file
186
core/configcenter/subscriber/etcd_test.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
package subscriber
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/discov"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
actionAdd = iota
|
||||||
|
actionDel
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfigCenterContainer(t *testing.T) {
|
||||||
|
type action struct {
|
||||||
|
act int
|
||||||
|
key string
|
||||||
|
val string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
do []action
|
||||||
|
expect []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "add one",
|
||||||
|
do: []action{
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "first",
|
||||||
|
val: "a",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expect: []string{
|
||||||
|
"a",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add two",
|
||||||
|
do: []action{
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "first",
|
||||||
|
val: "a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "second",
|
||||||
|
val: "b",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expect: []string{
|
||||||
|
"b",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add two, delete one",
|
||||||
|
do: []action{
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "first",
|
||||||
|
val: "a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "second",
|
||||||
|
val: "b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionDel,
|
||||||
|
key: "first",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expect: []string(nil),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add two, delete two",
|
||||||
|
do: []action{
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "first",
|
||||||
|
val: "a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "second",
|
||||||
|
val: "b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionDel,
|
||||||
|
key: "first",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionDel,
|
||||||
|
key: "second",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expect: []string(nil),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add two, dup values",
|
||||||
|
do: []action{
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "first",
|
||||||
|
val: "a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "second",
|
||||||
|
val: "b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "third",
|
||||||
|
val: "a",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expect: []string{"a"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add three, dup values, delete two, add one",
|
||||||
|
do: []action{
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "first",
|
||||||
|
val: "a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "second",
|
||||||
|
val: "b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "third",
|
||||||
|
val: "a",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionDel,
|
||||||
|
key: "first",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionDel,
|
||||||
|
key: "second",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
act: actionAdd,
|
||||||
|
key: "forth",
|
||||||
|
val: "c",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expect: []string{"c"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var changed bool
|
||||||
|
c := newContainer()
|
||||||
|
c.AddListener(func() {
|
||||||
|
changed = true
|
||||||
|
})
|
||||||
|
assert.Nil(t, c.GetValues())
|
||||||
|
assert.False(t, changed)
|
||||||
|
|
||||||
|
for _, order := range test.do {
|
||||||
|
if order.act == actionAdd {
|
||||||
|
c.OnAdd(discov.KV{
|
||||||
|
Key: order.key,
|
||||||
|
Val: order.val,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
c.OnDelete(discov.KV{
|
||||||
|
Key: order.key,
|
||||||
|
Val: order.val,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, changed)
|
||||||
|
assert.ElementsMatch(t, test.expect, c.GetValues())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -386,8 +386,9 @@ func (c *cluster) watch(cli EtcdClient, key watchKey, rev int64) {
|
|||||||
rev = c.load(cli, key)
|
rev = c.load(cli, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// log the error and retry
|
// log the error and retry with cooldown to prevent CPU/disk exhaustion
|
||||||
logc.Error(cli.Ctx(), err)
|
logc.Error(cli.Ctx(), err)
|
||||||
|
time.Sleep(coolDownUnstable.AroundDuration(coolDownInterval))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,8 +19,9 @@ type (
|
|||||||
exclusive bool
|
exclusive bool
|
||||||
key string
|
key string
|
||||||
exactMatch bool
|
exactMatch bool
|
||||||
items *container
|
items Container
|
||||||
}
|
}
|
||||||
|
KV = internal.KV
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewSubscriber returns a Subscriber.
|
// NewSubscriber returns a Subscriber.
|
||||||
@@ -35,7 +36,9 @@ func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscrib
|
|||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(sub)
|
opt(sub)
|
||||||
}
|
}
|
||||||
sub.items = newContainer(sub.exclusive)
|
if sub.items == nil {
|
||||||
|
sub.items = newContainer(sub.exclusive)
|
||||||
|
}
|
||||||
|
|
||||||
if err := internal.GetRegistry().Monitor(endpoints, key, sub.exactMatch, sub.items); err != nil {
|
if err := internal.GetRegistry().Monitor(endpoints, key, sub.exactMatch, sub.items); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -46,7 +49,7 @@ func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscrib
|
|||||||
|
|
||||||
// AddListener adds listener to s.
|
// AddListener adds listener to s.
|
||||||
func (s *Subscriber) AddListener(listener func()) {
|
func (s *Subscriber) AddListener(listener func()) {
|
||||||
s.items.addListener(listener)
|
s.items.AddListener(listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the subscriber.
|
// Close closes the subscriber.
|
||||||
@@ -56,7 +59,7 @@ func (s *Subscriber) Close() {
|
|||||||
|
|
||||||
// Values returns all the subscription values.
|
// Values returns all the subscription values.
|
||||||
func (s *Subscriber) Values() []string {
|
func (s *Subscriber) Values() []string {
|
||||||
return s.items.getValues()
|
return s.items.GetValues()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exclusive means that key value can only be 1:1,
|
// Exclusive means that key value can only be 1:1,
|
||||||
@@ -88,16 +91,32 @@ func WithSubEtcdTLS(certFile, certKeyFile, caFile string, insecureSkipVerify boo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type container struct {
|
// WithContainer provides a custom container to the subscriber.
|
||||||
exclusive bool
|
func WithContainer(container Container) SubOption {
|
||||||
values map[string][]string
|
return func(sub *Subscriber) {
|
||||||
mapping map[string]string
|
sub.items = container
|
||||||
snapshot atomic.Value
|
}
|
||||||
dirty *syncx.AtomicBool
|
|
||||||
listeners []func()
|
|
||||||
lock sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type (
|
||||||
|
Container interface {
|
||||||
|
OnAdd(kv internal.KV)
|
||||||
|
OnDelete(kv internal.KV)
|
||||||
|
AddListener(listener func())
|
||||||
|
GetValues() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
container struct {
|
||||||
|
exclusive bool
|
||||||
|
values map[string][]string
|
||||||
|
mapping map[string]string
|
||||||
|
snapshot atomic.Value
|
||||||
|
dirty *syncx.AtomicBool
|
||||||
|
listeners []func()
|
||||||
|
lock sync.Mutex
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
func newContainer(exclusive bool) *container {
|
func newContainer(exclusive bool) *container {
|
||||||
return &container{
|
return &container{
|
||||||
exclusive: exclusive,
|
exclusive: exclusive,
|
||||||
@@ -141,7 +160,7 @@ func (c *container) addKv(key, value string) ([]string, bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *container) addListener(listener func()) {
|
func (c *container) AddListener(listener func()) {
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
c.listeners = append(c.listeners, listener)
|
c.listeners = append(c.listeners, listener)
|
||||||
c.lock.Unlock()
|
c.lock.Unlock()
|
||||||
@@ -170,7 +189,7 @@ func (c *container) doRemoveKey(key string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *container) getValues() []string {
|
func (c *container) GetValues() []string {
|
||||||
if !c.dirty.True() {
|
if !c.dirty.True() {
|
||||||
return c.snapshot.Load().([]string)
|
return c.snapshot.Load().([]string)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -171,10 +171,10 @@ func TestContainer(t *testing.T) {
|
|||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
var changed bool
|
var changed bool
|
||||||
c := newContainer(exclusive)
|
c := newContainer(exclusive)
|
||||||
c.addListener(func() {
|
c.AddListener(func() {
|
||||||
changed = true
|
changed = true
|
||||||
})
|
})
|
||||||
assert.Nil(t, c.getValues())
|
assert.Nil(t, c.GetValues())
|
||||||
assert.False(t, changed)
|
assert.False(t, changed)
|
||||||
|
|
||||||
for _, order := range test.do {
|
for _, order := range test.do {
|
||||||
@@ -193,9 +193,9 @@ func TestContainer(t *testing.T) {
|
|||||||
|
|
||||||
assert.True(t, changed)
|
assert.True(t, changed)
|
||||||
assert.True(t, c.dirty.True())
|
assert.True(t, c.dirty.True())
|
||||||
assert.ElementsMatch(t, test.expect, c.getValues())
|
assert.ElementsMatch(t, test.expect, c.GetValues())
|
||||||
assert.False(t, c.dirty.True())
|
assert.False(t, c.dirty.True())
|
||||||
assert.ElementsMatch(t, test.expect, c.getValues())
|
assert.ElementsMatch(t, test.expect, c.GetValues())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -204,12 +204,14 @@ func TestContainer(t *testing.T) {
|
|||||||
func TestSubscriber(t *testing.T) {
|
func TestSubscriber(t *testing.T) {
|
||||||
sub := new(Subscriber)
|
sub := new(Subscriber)
|
||||||
Exclusive()(sub)
|
Exclusive()(sub)
|
||||||
sub.items = newContainer(sub.exclusive)
|
c := newContainer(sub.exclusive)
|
||||||
|
WithContainer(c)(sub)
|
||||||
|
sub.items = c
|
||||||
var count int32
|
var count int32
|
||||||
sub.AddListener(func() {
|
sub.AddListener(func() {
|
||||||
atomic.AddInt32(&count, 1)
|
atomic.AddInt32(&count, 1)
|
||||||
})
|
})
|
||||||
sub.items.notifyChange()
|
c.notifyChange()
|
||||||
assert.Empty(t, sub.Values())
|
assert.Empty(t, sub.Values())
|
||||||
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
|
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
|
||||||
}
|
}
|
||||||
@@ -229,12 +231,13 @@ func TestWithSubEtcdAccount(t *testing.T) {
|
|||||||
func TestWithExactMatch(t *testing.T) {
|
func TestWithExactMatch(t *testing.T) {
|
||||||
sub := new(Subscriber)
|
sub := new(Subscriber)
|
||||||
WithExactMatch()(sub)
|
WithExactMatch()(sub)
|
||||||
sub.items = newContainer(sub.exclusive)
|
c := newContainer(sub.exclusive)
|
||||||
|
sub.items = c
|
||||||
var count int32
|
var count int32
|
||||||
sub.AddListener(func() {
|
sub.AddListener(func() {
|
||||||
atomic.AddInt32(&count, 1)
|
atomic.AddInt32(&count, 1)
|
||||||
})
|
})
|
||||||
sub.items.notifyChange()
|
c.notifyChange()
|
||||||
assert.Empty(t, sub.Values())
|
assert.Empty(t, sub.Values())
|
||||||
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
|
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ func (s Stream) Count() (count int) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Distinct removes the duplicated items base on the given KeyFunc.
|
// Distinct removes the duplicated items based on the given KeyFunc.
|
||||||
func (s Stream) Distinct(fn KeyFunc) Stream {
|
func (s Stream) Distinct(fn KeyFunc) Stream {
|
||||||
source := make(chan any)
|
source := make(chan any)
|
||||||
|
|
||||||
@@ -459,7 +459,7 @@ func (s Stream) Tail(n int64) Stream {
|
|||||||
return Range(source)
|
return Range(source)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Walk lets the callers handle each item, the caller may write zero, one or more items base on the given item.
|
// Walk lets the callers handle each item, the caller may write zero, one or more items based on the given item.
|
||||||
func (s Stream) Walk(fn WalkFunc, opts ...Option) Stream {
|
func (s Stream) Walk(fn WalkFunc, opts ...Option) Stream {
|
||||||
option := buildOptions(opts...)
|
option := buildOptions(opts...)
|
||||||
if option.unlimitedWorkers {
|
if option.unlimitedWorkers {
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package fx
|
package fx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -13,6 +11,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
"github.com/zeromicro/go-zero/core/stringx"
|
||||||
"go.uber.org/goleak"
|
"go.uber.org/goleak"
|
||||||
)
|
)
|
||||||
@@ -238,7 +237,7 @@ func TestLast(t *testing.T) {
|
|||||||
|
|
||||||
func TestMap(t *testing.T) {
|
func TestMap(t *testing.T) {
|
||||||
runCheckedTest(t, func(t *testing.T) {
|
runCheckedTest(t, func(t *testing.T) {
|
||||||
log.SetOutput(io.Discard)
|
logtest.Discard(t)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
mapper MapFunc
|
mapper MapFunc
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ func (h *ConsistentHash) AddWithWeight(node any, weight int) {
|
|||||||
h.AddWithReplicas(node, replicas)
|
h.AddWithReplicas(node, replicas)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get returns the corresponding node from h base on the given v.
|
// Get returns the corresponding node from h based on the given v.
|
||||||
func (h *ConsistentHash) Get(v any) (any, bool) {
|
func (h *ConsistentHash) Get(v any) (any, bool) {
|
||||||
h.lock.RLock()
|
h.lock.RLock()
|
||||||
defer h.lock.RUnlock()
|
defer h.lock.RUnlock()
|
||||||
|
|||||||
@@ -1,47 +1,70 @@
|
|||||||
package logx
|
package logx
|
||||||
|
|
||||||
// A LogConf is a logging config.
|
type (
|
||||||
type LogConf struct {
|
// A LogConf is a logging config.
|
||||||
// ServiceName represents the service name.
|
LogConf struct {
|
||||||
ServiceName string `json:",optional"`
|
// ServiceName represents the service name.
|
||||||
// Mode represents the logging mode, default is `console`.
|
ServiceName string `json:",optional"`
|
||||||
// console: log to console.
|
// Mode represents the logging mode, default is `console`.
|
||||||
// file: log to file.
|
// console: log to console.
|
||||||
// volume: used in k8s, prepend the hostname to the log file name.
|
// file: log to file.
|
||||||
Mode string `json:",default=console,options=[console,file,volume]"`
|
// volume: used in k8s, prepend the hostname to the log file name.
|
||||||
// Encoding represents the encoding type, default is `json`.
|
Mode string `json:",default=console,options=[console,file,volume]"`
|
||||||
// json: json encoding.
|
// Encoding represents the encoding type, default is `json`.
|
||||||
// plain: plain text encoding, typically used in development.
|
// json: json encoding.
|
||||||
Encoding string `json:",default=json,options=[json,plain]"`
|
// plain: plain text encoding, typically used in development.
|
||||||
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
|
Encoding string `json:",default=json,options=[json,plain]"`
|
||||||
TimeFormat string `json:",optional"`
|
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||||
// Path represents the log file path, default is `logs`.
|
TimeFormat string `json:",optional"`
|
||||||
Path string `json:",default=logs"`
|
// Path represents the log file path, default is `logs`.
|
||||||
// Level represents the log level, default is `info`.
|
Path string `json:",default=logs"`
|
||||||
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
// Level represents the log level, default is `info`.
|
||||||
// MaxContentLength represents the max content bytes, default is no limit.
|
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
||||||
MaxContentLength uint32 `json:",optional"`
|
// MaxContentLength represents the max content bytes, default is no limit.
|
||||||
// Compress represents whether to compress the log file, default is `false`.
|
MaxContentLength uint32 `json:",optional"`
|
||||||
Compress bool `json:",optional"`
|
// Compress represents whether to compress the log file, default is `false`.
|
||||||
// Stat represents whether to log statistics, default is `true`.
|
Compress bool `json:",optional"`
|
||||||
Stat bool `json:",default=true"`
|
// Stat represents whether to log statistics, default is `true`.
|
||||||
// KeepDays represents how many days the log files will be kept. Default to keep all files.
|
Stat bool `json:",default=true"`
|
||||||
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
|
// KeepDays represents how many days the log files will be kept. Default to keep all files.
|
||||||
KeepDays int `json:",optional"`
|
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
|
||||||
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
|
KeepDays int `json:",optional"`
|
||||||
StackCooldownMillis int `json:",default=100"`
|
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
|
||||||
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
StackCooldownMillis int `json:",default=100"`
|
||||||
// Only take effect when RotationRuleType is `size`.
|
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
||||||
// Even though `MaxBackups` sets 0, log files will still be removed
|
// Only take effect when RotationRuleType is `size`.
|
||||||
// if the `KeepDays` limitation is reached.
|
// Even though `MaxBackups` sets 0, log files will still be removed
|
||||||
MaxBackups int `json:",default=0"`
|
// if the `KeepDays` limitation is reached.
|
||||||
// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.
|
MaxBackups int `json:",default=0"`
|
||||||
// Only take effect when RotationRuleType is `size`
|
// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.
|
||||||
MaxSize int `json:",default=0"`
|
// Only take effect when RotationRuleType is `size`
|
||||||
// Rotation represents the type of log rotation rule. Default is `daily`.
|
MaxSize int `json:",default=0"`
|
||||||
// daily: daily rotation.
|
// Rotation represents the type of log rotation rule. Default is `daily`.
|
||||||
// size: size limited rotation.
|
// daily: daily rotation.
|
||||||
Rotation string `json:",default=daily,options=[daily,size]"`
|
// size: size limited rotation.
|
||||||
// FileTimeFormat represents the time format for file name, default is `2006-01-02T15:04:05.000Z07:00`.
|
Rotation string `json:",default=daily,options=[daily,size]"`
|
||||||
FileTimeFormat string `json:",optional"`
|
// FileTimeFormat represents the time format for file name, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||||
}
|
FileTimeFormat string `json:",optional"`
|
||||||
|
// FieldKeys represents the field keys.
|
||||||
|
FieldKeys fieldKeyConf `json:",optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldKeyConf struct {
|
||||||
|
// CallerKey represents the caller key.
|
||||||
|
CallerKey string `json:",default=caller"`
|
||||||
|
// ContentKey represents the content key.
|
||||||
|
ContentKey string `json:",default=content"`
|
||||||
|
// DurationKey represents the duration key.
|
||||||
|
DurationKey string `json:",default=duration"`
|
||||||
|
// LevelKey represents the level key.
|
||||||
|
LevelKey string `json:",default=level"`
|
||||||
|
// SpanKey represents the span key.
|
||||||
|
SpanKey string `json:",default=span"`
|
||||||
|
// TimestampKey represents the timestamp key.
|
||||||
|
TimestampKey string `json:",default=@timestamp"`
|
||||||
|
// TraceKey represents the trace key.
|
||||||
|
TraceKey string `json:",default=trace"`
|
||||||
|
// TruncatedKey represents the truncated key.
|
||||||
|
TruncatedKey string `json:",default=truncated"`
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|||||||
@@ -276,7 +276,8 @@ func SetUp(c LogConf) (err error) {
|
|||||||
// Because multiple services in one process might call SetUp respectively.
|
// Because multiple services in one process might call SetUp respectively.
|
||||||
// Need to wait for the first caller to complete the execution.
|
// Need to wait for the first caller to complete the execution.
|
||||||
setupOnce.Do(func() {
|
setupOnce.Do(func() {
|
||||||
setupLogLevel(c)
|
setupLogLevel(c.Level)
|
||||||
|
setupFieldKeys(c.FieldKeys)
|
||||||
|
|
||||||
if !c.Stat {
|
if !c.Stat {
|
||||||
DisableStat()
|
DisableStat()
|
||||||
@@ -480,8 +481,35 @@ func handleOptions(opts []LogOption) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupLogLevel(c LogConf) {
|
func setupFieldKeys(c fieldKeyConf) {
|
||||||
switch c.Level {
|
if len(c.CallerKey) > 0 {
|
||||||
|
callerKey = c.CallerKey
|
||||||
|
}
|
||||||
|
if len(c.ContentKey) > 0 {
|
||||||
|
contentKey = c.ContentKey
|
||||||
|
}
|
||||||
|
if len(c.DurationKey) > 0 {
|
||||||
|
durationKey = c.DurationKey
|
||||||
|
}
|
||||||
|
if len(c.LevelKey) > 0 {
|
||||||
|
levelKey = c.LevelKey
|
||||||
|
}
|
||||||
|
if len(c.SpanKey) > 0 {
|
||||||
|
spanKey = c.SpanKey
|
||||||
|
}
|
||||||
|
if len(c.TimestampKey) > 0 {
|
||||||
|
timestampKey = c.TimestampKey
|
||||||
|
}
|
||||||
|
if len(c.TraceKey) > 0 {
|
||||||
|
traceKey = c.TraceKey
|
||||||
|
}
|
||||||
|
if len(c.TruncatedKey) > 0 {
|
||||||
|
truncatedKey = c.TruncatedKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupLogLevel(level string) {
|
||||||
|
switch level {
|
||||||
case levelDebug:
|
case levelDebug:
|
||||||
SetLevel(DebugLevel)
|
SetLevel(DebugLevel)
|
||||||
case levelInfo:
|
case levelInfo:
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
|
"go.opentelemetry.io/otel/sdk/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -777,15 +779,9 @@ func TestSetup(t *testing.T) {
|
|||||||
MaxBackups: 3,
|
MaxBackups: 3,
|
||||||
MaxSize: 1024 * 1024,
|
MaxSize: 1024 * 1024,
|
||||||
}))
|
}))
|
||||||
setupLogLevel(LogConf{
|
setupLogLevel(levelInfo)
|
||||||
Level: levelInfo,
|
setupLogLevel(levelError)
|
||||||
})
|
setupLogLevel(levelSevere)
|
||||||
setupLogLevel(LogConf{
|
|
||||||
Level: levelError,
|
|
||||||
})
|
|
||||||
setupLogLevel(LogConf{
|
|
||||||
Level: levelSevere,
|
|
||||||
})
|
|
||||||
_, err := createOutput("")
|
_, err := createOutput("")
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
Disable()
|
Disable()
|
||||||
@@ -1157,3 +1153,66 @@ func (s *countingStringer) String() string {
|
|||||||
atomic.AddInt32(&s.count, 1)
|
atomic.AddInt32(&s.count, 1)
|
||||||
return "countingStringer"
|
return "countingStringer"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLogKey(t *testing.T) {
|
||||||
|
setupOnce = sync.Once{}
|
||||||
|
MustSetup(LogConf{
|
||||||
|
ServiceName: "any",
|
||||||
|
Mode: "console",
|
||||||
|
Encoding: "json",
|
||||||
|
TimeFormat: timeFormat,
|
||||||
|
FieldKeys: fieldKeyConf{
|
||||||
|
CallerKey: "_caller",
|
||||||
|
ContentKey: "_content",
|
||||||
|
DurationKey: "_duration",
|
||||||
|
LevelKey: "_level",
|
||||||
|
SpanKey: "_span",
|
||||||
|
TimestampKey: "_timestamp",
|
||||||
|
TraceKey: "_trace",
|
||||||
|
TruncatedKey: "_truncated",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
setupFieldKeys(fieldKeyConf{
|
||||||
|
CallerKey: defaultCallerKey,
|
||||||
|
ContentKey: defaultContentKey,
|
||||||
|
DurationKey: defaultDurationKey,
|
||||||
|
LevelKey: defaultLevelKey,
|
||||||
|
SpanKey: defaultSpanKey,
|
||||||
|
TimestampKey: defaultTimestampKey,
|
||||||
|
TraceKey: defaultTraceKey,
|
||||||
|
TruncatedKey: defaultTruncatedKey,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
const message = "hello there"
|
||||||
|
w := new(mockWriter)
|
||||||
|
old := writer.Swap(w)
|
||||||
|
defer writer.Store(old)
|
||||||
|
|
||||||
|
otp := otel.GetTracerProvider()
|
||||||
|
tp := trace.NewTracerProvider(trace.WithSampler(trace.AlwaysSample()))
|
||||||
|
otel.SetTracerProvider(tp)
|
||||||
|
defer otel.SetTracerProvider(otp)
|
||||||
|
|
||||||
|
ctx, span := tp.Tracer("trace-id").Start(context.Background(), "span-id")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
WithContext(ctx).WithDuration(time.Second).Info(message)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
var m map[string]string
|
||||||
|
if err := json.Unmarshal([]byte(w.String()), &m); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, "info", m["_level"])
|
||||||
|
assert.Equal(t, message, m["_content"])
|
||||||
|
assert.Equal(t, "1000.0ms", m["_duration"])
|
||||||
|
assert.Regexp(t, `logx/logs_test.go:\d+`, m["_caller"])
|
||||||
|
assert.NotEmpty(t, m["_trace"])
|
||||||
|
assert.NotEmpty(t, m["_span"])
|
||||||
|
parsedTime, err := time.Parse(timeFormat, m["_timestamp"])
|
||||||
|
assert.True(t, err == nil)
|
||||||
|
assert.Equal(t, now.Minute(), parsedTime.Minute())
|
||||||
|
}
|
||||||
|
|||||||
@@ -423,3 +423,49 @@ type mockValue struct {
|
|||||||
Foo string `json:"foo"`
|
Foo string `json:"foo"`
|
||||||
Content any `json:"content"`
|
Content any `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type testJson struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
Score float64 `json:"score"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t testJson) MarshalJSON() ([]byte, error) {
|
||||||
|
type testJsonImpl testJson
|
||||||
|
return json.Marshal(testJsonImpl(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t testJson) String() string {
|
||||||
|
return fmt.Sprintf("%s %d %f", t.Name, t.Age, t.Score)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogWithJson(t *testing.T) {
|
||||||
|
w := new(mockWriter)
|
||||||
|
old := writer.Swap(w)
|
||||||
|
writer.lock.RLock()
|
||||||
|
defer func() {
|
||||||
|
writer.lock.RUnlock()
|
||||||
|
writer.Store(old)
|
||||||
|
}()
|
||||||
|
|
||||||
|
l := WithContext(context.Background()).WithFields(Field("bar", testJson{
|
||||||
|
Name: "foo",
|
||||||
|
Age: 1,
|
||||||
|
Score: 1.0,
|
||||||
|
}))
|
||||||
|
l.Info(testlog)
|
||||||
|
|
||||||
|
type mockValue2 struct {
|
||||||
|
mockValue
|
||||||
|
Bar testJson `json:"bar"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var val mockValue2
|
||||||
|
err := json.Unmarshal([]byte(w.String()), &val)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, testlog, val.Content)
|
||||||
|
assert.Equal(t, "foo", val.Bar.Name)
|
||||||
|
assert.Equal(t, 1, val.Bar.Age)
|
||||||
|
assert.Equal(t, 1.0, val.Bar.Score)
|
||||||
|
}
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ type (
|
|||||||
gzip bool
|
gzip bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// SizeLimitRotateRule a rotation rule that make the log file rotated base on size
|
// SizeLimitRotateRule a rotation rule that makes the log file rotated based on size
|
||||||
SizeLimitRotateRule struct {
|
SizeLimitRotateRule struct {
|
||||||
DailyRotateRule
|
DailyRotateRule
|
||||||
maxSize int64
|
maxSize int64
|
||||||
|
|||||||
@@ -53,14 +53,14 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
callerKey = "caller"
|
defaultCallerKey = "caller"
|
||||||
contentKey = "content"
|
defaultContentKey = "content"
|
||||||
durationKey = "duration"
|
defaultDurationKey = "duration"
|
||||||
levelKey = "level"
|
defaultLevelKey = "level"
|
||||||
spanKey = "span"
|
defaultSpanKey = "span"
|
||||||
timestampKey = "@timestamp"
|
defaultTimestampKey = "@timestamp"
|
||||||
traceKey = "trace"
|
defaultTraceKey = "trace"
|
||||||
truncatedKey = "truncated"
|
defaultTruncatedKey = "truncated"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -73,3 +73,14 @@ var (
|
|||||||
|
|
||||||
truncatedField = Field(truncatedKey, true)
|
truncatedField = Field(truncatedKey, true)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
callerKey = defaultCallerKey
|
||||||
|
contentKey = defaultContentKey
|
||||||
|
durationKey = defaultDurationKey
|
||||||
|
levelKey = defaultLevelKey
|
||||||
|
spanKey = defaultSpanKey
|
||||||
|
timestampKey = defaultTimestampKey
|
||||||
|
traceKey = defaultTraceKey
|
||||||
|
truncatedKey = defaultTruncatedKey
|
||||||
|
)
|
||||||
|
|||||||
@@ -212,7 +212,6 @@ func newFileWriter(c LogConf) (Writer, error) {
|
|||||||
statFile := path.Join(c.Path, statFilename)
|
statFile := path.Join(c.Path, statFilename)
|
||||||
|
|
||||||
handleOptions(opts)
|
handleOptions(opts)
|
||||||
setupLogLevel(c)
|
|
||||||
|
|
||||||
if infoLog, err = createOutput(accessFile); err != nil {
|
if infoLog, err = createOutput(accessFile); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -423,6 +422,8 @@ func processFieldValue(value any) any {
|
|||||||
times = append(times, fmt.Sprint(t))
|
times = append(times, fmt.Sprint(t))
|
||||||
}
|
}
|
||||||
return times
|
return times
|
||||||
|
case json.Marshaler:
|
||||||
|
return val
|
||||||
case fmt.Stringer:
|
case fmt.Stringer:
|
||||||
return encodeStringer(val)
|
return encodeStringer(val)
|
||||||
case []fmt.Stringer:
|
case []fmt.Stringer:
|
||||||
@@ -443,6 +444,8 @@ func wrapLevelWithColor(level string) string {
|
|||||||
colour = color.FgRed
|
colour = color.FgRed
|
||||||
case levelError:
|
case levelError:
|
||||||
colour = color.FgRed
|
colour = color.FgRed
|
||||||
|
case levelSevere:
|
||||||
|
colour = color.FgRed
|
||||||
case levelFatal:
|
case levelFatal:
|
||||||
colour = color.FgRed
|
colour = color.FgRed
|
||||||
case levelInfo:
|
case levelInfo:
|
||||||
|
|||||||
@@ -104,14 +104,13 @@ func convertToString(val any, fullName string) (string, error) {
|
|||||||
func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
|
func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
|
||||||
switch kind {
|
switch kind {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
switch strings.ToLower(str) {
|
if str == "1" || strings.EqualFold(str, "true") {
|
||||||
case "1", "true":
|
|
||||||
return true, nil
|
return true, nil
|
||||||
case "0", "false":
|
|
||||||
return false, nil
|
|
||||||
default:
|
|
||||||
return false, errTypeMismatch
|
|
||||||
}
|
}
|
||||||
|
if str == "0" || strings.EqualFold(str, "false") {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, errTypeMismatch
|
||||||
case reflect.Int:
|
case reflect.Int:
|
||||||
return strconv.ParseInt(str, 10, intSize)
|
return strconv.ParseInt(str, 10, intSize)
|
||||||
case reflect.Int8:
|
case reflect.Int8:
|
||||||
|
|||||||
@@ -334,3 +334,43 @@ func TestValidateValueRange(t *testing.T) {
|
|||||||
func TestSetMatchedPrimitiveValue(t *testing.T) {
|
func TestSetMatchedPrimitiveValue(t *testing.T) {
|
||||||
assert.Error(t, setMatchedPrimitiveValue(reflect.Func, reflect.ValueOf(2), "1"))
|
assert.Error(t, setMatchedPrimitiveValue(reflect.Func, reflect.ValueOf(2), "1"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertTypeFromString_Bool(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
// true cases
|
||||||
|
{name: "1", input: "1", want: true, wantErr: false},
|
||||||
|
{name: "true lowercase", input: "true", want: true, wantErr: false},
|
||||||
|
{name: "True mixed", input: "True", want: true, wantErr: false},
|
||||||
|
{name: "TRUE uppercase", input: "TRUE", want: true, wantErr: false},
|
||||||
|
{name: "TrUe mixed", input: "TrUe", want: true, wantErr: false},
|
||||||
|
// false cases
|
||||||
|
{name: "0", input: "0", want: false, wantErr: false},
|
||||||
|
{name: "false lowercase", input: "false", want: false, wantErr: false},
|
||||||
|
{name: "False mixed", input: "False", want: false, wantErr: false},
|
||||||
|
{name: "FALSE uppercase", input: "FALSE", want: false, wantErr: false},
|
||||||
|
{name: "FaLsE mixed", input: "FaLsE", want: false, wantErr: false},
|
||||||
|
// error cases
|
||||||
|
{name: "invalid yes", input: "yes", want: false, wantErr: true},
|
||||||
|
{name: "invalid no", input: "no", want: false, wantErr: true},
|
||||||
|
{name: "invalid empty", input: "", want: false, wantErr: true},
|
||||||
|
{name: "invalid 2", input: "2", want: false, wantErr: true},
|
||||||
|
{name: "invalid truee", input: "truee", want: false, wantErr: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := convertTypeFromString(reflect.Bool, tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// An Unstable is used to generate random value around the mean value base on given deviation.
|
// An Unstable is used to generate random value around the mean value based on given deviation.
|
||||||
type Unstable struct {
|
type Unstable struct {
|
||||||
deviation float64
|
deviation float64
|
||||||
r *rand.Rand
|
r *rand.Rand
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ package mr
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"runtime/debug"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
@@ -183,12 +186,16 @@ func buildOptions(opts ...Option) *mapReduceOptions {
|
|||||||
return options
|
return options
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildPanicInfo(r any, stack []byte) string {
|
||||||
|
return fmt.Sprintf("%+v\n\n%s", r, strings.TrimSpace(string(stack)))
|
||||||
|
}
|
||||||
|
|
||||||
func buildSource[T any](generate GenerateFunc[T], panicChan *onceChan) chan T {
|
func buildSource[T any](generate GenerateFunc[T], panicChan *onceChan) chan T {
|
||||||
source := make(chan T)
|
source := make(chan T)
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
panicChan.write(r)
|
panicChan.write(buildPanicInfo(r, debug.Stack()))
|
||||||
}
|
}
|
||||||
close(source)
|
close(source)
|
||||||
}()
|
}()
|
||||||
@@ -235,7 +242,7 @@ func executeMappers[T, U any](mCtx mapperContext[T, U]) {
|
|||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
atomic.AddInt32(&failed, 1)
|
atomic.AddInt32(&failed, 1)
|
||||||
mCtx.panicChan.write(r)
|
mCtx.panicChan.write(buildPanicInfo(r, debug.Stack()))
|
||||||
}
|
}
|
||||||
wg.Done()
|
wg.Done()
|
||||||
<-pool
|
<-pool
|
||||||
@@ -289,7 +296,7 @@ func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, m
|
|||||||
defer func() {
|
defer func() {
|
||||||
drain(collector)
|
drain(collector)
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
panicChan.write(r)
|
panicChan.write(buildPanicInfo(r, debug.Stack()))
|
||||||
}
|
}
|
||||||
finish()
|
finish()
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -3,8 +3,7 @@ package mr
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"fmt"
|
||||||
"log"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -16,9 +15,6 @@ import (
|
|||||||
|
|
||||||
var errDummy = errors.New("dummy")
|
var errDummy = errors.New("dummy")
|
||||||
|
|
||||||
func init() {
|
|
||||||
log.SetOutput(io.Discard)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFinish(t *testing.T) {
|
func TestFinish(t *testing.T) {
|
||||||
defer goleak.VerifyNone(t)
|
defer goleak.VerifyNone(t)
|
||||||
@@ -148,11 +144,28 @@ func TestForEach(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, tasks/2, int(count))
|
assert.Equal(t, tasks/2, int(count))
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
t.Run("all", func(t *testing.T) {
|
func TestPanics(t *testing.T) {
|
||||||
defer goleak.VerifyNone(t)
|
defer goleak.VerifyNone(t)
|
||||||
|
|
||||||
|
const tasks = 1000
|
||||||
|
verify := func(t *testing.T, r any) {
|
||||||
|
panicStr := fmt.Sprintf("%v", r)
|
||||||
|
assert.Contains(t, panicStr, "foo")
|
||||||
|
assert.Contains(t, panicStr, "goroutine")
|
||||||
|
assert.Contains(t, panicStr, "runtime/debug.Stack")
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("ForEach run panics", func(t *testing.T) {
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
verify(t, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
assert.PanicsWithValue(t, "foo", func() {
|
|
||||||
ForEach(func(source chan<- int) {
|
ForEach(func(source chan<- int) {
|
||||||
for i := 0; i < tasks; i++ {
|
for i := 0; i < tasks; i++ {
|
||||||
source <- i
|
source <- i
|
||||||
@@ -162,28 +175,31 @@ func TestForEach(t *testing.T) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
|
||||||
|
|
||||||
func TestGeneratePanic(t *testing.T) {
|
t.Run("ForEach generate panics", func(t *testing.T) {
|
||||||
defer goleak.VerifyNone(t)
|
assert.Panics(t, func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
verify(t, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
t.Run("all", func(t *testing.T) {
|
|
||||||
assert.PanicsWithValue(t, "foo", func() {
|
|
||||||
ForEach(func(source chan<- int) {
|
ForEach(func(source chan<- int) {
|
||||||
panic("foo")
|
panic("foo")
|
||||||
}, func(item int) {
|
}, func(item int) {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
|
||||||
|
|
||||||
func TestMapperPanic(t *testing.T) {
|
|
||||||
defer goleak.VerifyNone(t)
|
|
||||||
|
|
||||||
const tasks = 1000
|
|
||||||
var run int32
|
var run int32
|
||||||
t.Run("all", func(t *testing.T) {
|
t.Run("Mapper panics", func(t *testing.T) {
|
||||||
assert.PanicsWithValue(t, "foo", func() {
|
assert.Panics(t, func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
verify(t, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
_, _ = MapReduce(func(source chan<- int) {
|
_, _ = MapReduce(func(source chan<- int) {
|
||||||
for i := 0; i < tasks; i++ {
|
for i := 0; i < tasks; i++ {
|
||||||
source <- i
|
source <- i
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"runtime/debug"
|
||||||
|
"runtime/metrics"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,10 +30,29 @@ func displayStatsWithWriter(writer io.Writer, interval ...time.Duration) {
|
|||||||
ticker := time.NewTicker(duration)
|
ticker := time.NewTicker(duration)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
var m runtime.MemStats
|
var (
|
||||||
runtime.ReadMemStats(&m)
|
alloc, totalAlloc, sys uint64
|
||||||
|
samples = []metrics.Sample{
|
||||||
|
{Name: "/memory/classes/heap/objects:bytes"},
|
||||||
|
{Name: "/gc/heap/allocs:bytes"},
|
||||||
|
{Name: "/memory/classes/total:bytes"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
metrics.Read(samples)
|
||||||
|
|
||||||
|
if samples[0].Value.Kind() == metrics.KindUint64 {
|
||||||
|
alloc = samples[0].Value.Uint64()
|
||||||
|
}
|
||||||
|
if samples[1].Value.Kind() == metrics.KindUint64 {
|
||||||
|
totalAlloc = samples[1].Value.Uint64()
|
||||||
|
}
|
||||||
|
if samples[2].Value.Kind() == metrics.KindUint64 {
|
||||||
|
sys = samples[2].Value.Uint64()
|
||||||
|
}
|
||||||
|
var stats debug.GCStats
|
||||||
|
debug.ReadGCStats(&stats)
|
||||||
fmt.Fprintf(writer, "Goroutines: %d, Alloc: %vm, TotalAlloc: %vm, Sys: %vm, NumGC: %v\n",
|
fmt.Fprintf(writer, "Goroutines: %d, Alloc: %vm, TotalAlloc: %vm, Sys: %vm, NumGC: %v\n",
|
||||||
runtime.NumGoroutine(), m.Alloc/mega, m.TotalAlloc/mega, m.Sys/mega, m.NumGC)
|
runtime.NumGoroutine(), alloc/mega, totalAlloc/mega, sys/mega, stats.NumGC)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
package stat
|
package stat
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"runtime"
|
"runtime/debug"
|
||||||
|
"runtime/metrics"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -56,8 +57,28 @@ func bToMb(b uint64) float32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func printUsage() {
|
func printUsage() {
|
||||||
var m runtime.MemStats
|
var (
|
||||||
runtime.ReadMemStats(&m)
|
alloc, totalAlloc, sys uint64
|
||||||
|
samples = []metrics.Sample{
|
||||||
|
{Name: "/memory/classes/heap/objects:bytes"},
|
||||||
|
{Name: "/gc/heap/allocs:bytes"},
|
||||||
|
{Name: "/memory/classes/total:bytes"},
|
||||||
|
}
|
||||||
|
stats debug.GCStats
|
||||||
|
)
|
||||||
|
metrics.Read(samples)
|
||||||
|
|
||||||
|
if samples[0].Value.Kind() == metrics.KindUint64 {
|
||||||
|
alloc = samples[0].Value.Uint64()
|
||||||
|
}
|
||||||
|
if samples[1].Value.Kind() == metrics.KindUint64 {
|
||||||
|
totalAlloc = samples[1].Value.Uint64()
|
||||||
|
}
|
||||||
|
if samples[2].Value.Kind() == metrics.KindUint64 {
|
||||||
|
sys = samples[2].Value.Uint64()
|
||||||
|
}
|
||||||
|
debug.ReadGCStats(&stats)
|
||||||
|
|
||||||
logx.Statf("CPU: %dm, MEMORY: Alloc=%.1fMi, TotalAlloc=%.1fMi, Sys=%.1fMi, NumGC=%d",
|
logx.Statf("CPU: %dm, MEMORY: Alloc=%.1fMi, TotalAlloc=%.1fMi, Sys=%.1fMi, NumGC=%d",
|
||||||
CpuUsage(), bToMb(m.Alloc), bToMb(m.TotalAlloc), bToMb(m.Sys), m.NumGC)
|
CpuUsage(), bToMb(alloc), bToMb(totalAlloc), bToMb(sys), stats.NumGC)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -532,7 +532,7 @@ func createModel(t *testing.T, coll mon.Collection) *Model {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mustNewTestModel returns a test Model with the given cache.
|
// mustNewTestModel returns a test Model with the given cache.
|
||||||
func mustNewTestModel(collection mon.Collection, c cache.CacheConf, opts ...cache.Option) *Model {
|
func mustNewTestModel(collection mon.Collection, c cache.CacheConf, opts ...cache.Option) *Model {
|
||||||
return &Model{
|
return &Model{
|
||||||
Model: &mon.Model{
|
Model: &mon.Model{
|
||||||
|
|||||||
@@ -259,12 +259,34 @@ func (s *Redis) BitPosCtx(ctx context.Context, key string, bit, start, end int64
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Blpop uses passed in redis connection to execute blocking queries.
|
// Blpop uses passed in redis connection to execute blocking queries.
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode to avoid
|
||||||
|
// exhausting the connection pool. Blocking commands hold connections for extended periods and should
|
||||||
|
// not share the regular connection pool.
|
||||||
|
//
|
||||||
|
// Example usage:
|
||||||
|
//
|
||||||
|
// node, err := redis.CreateBlockingNode(rds)
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
// defer node.Close()
|
||||||
|
//
|
||||||
|
// value, err := rds.Blpop(node, "mylist")
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
//
|
||||||
// Doesn't benefit from pooling redis connections of blocking queries
|
// Doesn't benefit from pooling redis connections of blocking queries
|
||||||
func (s *Redis) Blpop(node RedisNode, key string) (string, error) {
|
func (s *Redis) Blpop(node RedisNode, key string) (string, error) {
|
||||||
return s.BlpopCtx(context.Background(), node, key)
|
return s.BlpopCtx(context.Background(), node, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BlpopCtx uses passed in redis connection to execute blocking queries.
|
// BlpopCtx uses passed in redis connection to execute blocking queries.
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
|
||||||
|
// See Blpop for usage examples.
|
||||||
|
//
|
||||||
// Doesn't benefit from pooling redis connections of blocking queries
|
// Doesn't benefit from pooling redis connections of blocking queries
|
||||||
func (s *Redis) BlpopCtx(ctx context.Context, node RedisNode, key string) (string, error) {
|
func (s *Redis) BlpopCtx(ctx context.Context, node RedisNode, key string) (string, error) {
|
||||||
return s.BlpopWithTimeoutCtx(ctx, node, blockingQueryTimeout, key)
|
return s.BlpopWithTimeoutCtx(ctx, node, blockingQueryTimeout, key)
|
||||||
@@ -272,12 +294,18 @@ func (s *Redis) BlpopCtx(ctx context.Context, node RedisNode, key string) (strin
|
|||||||
|
|
||||||
// BlpopEx uses passed in redis connection to execute blpop command.
|
// BlpopEx uses passed in redis connection to execute blpop command.
|
||||||
// The difference against Blpop is that this method returns a bool to indicate success.
|
// The difference against Blpop is that this method returns a bool to indicate success.
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
|
||||||
|
// See Blpop for usage examples.
|
||||||
func (s *Redis) BlpopEx(node RedisNode, key string) (string, bool, error) {
|
func (s *Redis) BlpopEx(node RedisNode, key string) (string, bool, error) {
|
||||||
return s.BlpopExCtx(context.Background(), node, key)
|
return s.BlpopExCtx(context.Background(), node, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BlpopExCtx uses passed in redis connection to execute blpop command.
|
// BlpopExCtx uses passed in redis connection to execute blpop command.
|
||||||
// The difference against Blpop is that this method returns a bool to indicate success.
|
// The difference against Blpop is that this method returns a bool to indicate success.
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
|
||||||
|
// See Blpop for usage examples.
|
||||||
func (s *Redis) BlpopExCtx(ctx context.Context, node RedisNode, key string) (string, bool, error) {
|
func (s *Redis) BlpopExCtx(ctx context.Context, node RedisNode, key string) (string, bool, error) {
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return "", false, ErrNilNode
|
return "", false, ErrNilNode
|
||||||
@@ -297,12 +325,18 @@ func (s *Redis) BlpopExCtx(ctx context.Context, node RedisNode, key string) (str
|
|||||||
|
|
||||||
// BlpopWithTimeout uses passed in redis connection to execute blpop command.
|
// BlpopWithTimeout uses passed in redis connection to execute blpop command.
|
||||||
// Control blocking query timeout
|
// Control blocking query timeout
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
|
||||||
|
// See Blpop for usage examples.
|
||||||
func (s *Redis) BlpopWithTimeout(node RedisNode, timeout time.Duration, key string) (string, error) {
|
func (s *Redis) BlpopWithTimeout(node RedisNode, timeout time.Duration, key string) (string, error) {
|
||||||
return s.BlpopWithTimeoutCtx(context.Background(), node, timeout, key)
|
return s.BlpopWithTimeoutCtx(context.Background(), node, timeout, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BlpopWithTimeoutCtx uses passed in redis connection to execute blpop command.
|
// BlpopWithTimeoutCtx uses passed in redis connection to execute blpop command.
|
||||||
// Control blocking query timeout
|
// Control blocking query timeout
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
|
||||||
|
// See Blpop for usage examples.
|
||||||
func (s *Redis) BlpopWithTimeoutCtx(ctx context.Context, node RedisNode, timeout time.Duration,
|
func (s *Redis) BlpopWithTimeoutCtx(ctx context.Context, node RedisNode, timeout time.Duration,
|
||||||
key string) (string, error) {
|
key string) (string, error) {
|
||||||
if node == nil {
|
if node == nil {
|
||||||
@@ -630,6 +664,28 @@ func (s *Redis) GetDelCtx(ctx context.Context, key string) (string, error) {
|
|||||||
return val, err
|
return val, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetEx is the implementation of redis getex command.
|
||||||
|
// Available since: redis version 6.2.0
|
||||||
|
func (s *Redis) GetEx(key string, seconds int) (string, error) {
|
||||||
|
return s.GetExCtx(context.Background(), key, seconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetExCtx is the implementation of redis getex command.
|
||||||
|
// Available since: redis version 6.2.0
|
||||||
|
func (s *Redis) GetExCtx(ctx context.Context, key string, seconds int) (string, error) {
|
||||||
|
conn, err := getRedis(s)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err := conn.GetEx(ctx, key, time.Duration(seconds)*time.Second).Result()
|
||||||
|
if errors.Is(err, red.Nil) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
// GetSet is the implementation of redis getset command.
|
// GetSet is the implementation of redis getset command.
|
||||||
func (s *Redis) GetSet(key, value string) (string, error) {
|
func (s *Redis) GetSet(key, value string) (string, error) {
|
||||||
return s.GetSetCtx(context.Background(), key, value)
|
return s.GetSetCtx(context.Background(), key, value)
|
||||||
@@ -1840,6 +1896,29 @@ func (s *Redis) XInfoStreamCtx(ctx context.Context, stream string) (*red.XInfoSt
|
|||||||
|
|
||||||
// XReadGroup reads messages from Redis streams as part of a consumer group.
|
// XReadGroup reads messages from Redis streams as part of a consumer group.
|
||||||
// It allows for distributed processing of stream messages with automatic message delivery semantics.
|
// It allows for distributed processing of stream messages with automatic message delivery semantics.
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode to avoid
|
||||||
|
// exhausting the connection pool. Blocking commands hold connections for extended periods and should
|
||||||
|
// not share the regular connection pool.
|
||||||
|
//
|
||||||
|
// Example usage:
|
||||||
|
//
|
||||||
|
// node, err := redis.CreateBlockingNode(rds)
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
// defer node.Close()
|
||||||
|
//
|
||||||
|
// streams, err := rds.XReadGroup(
|
||||||
|
// node, // RedisNode created with CreateBlockingNode
|
||||||
|
// "mygroup", // consumer group name
|
||||||
|
// "consumer1", // consumer ID
|
||||||
|
// 10, // max number of messages to read
|
||||||
|
// 5*time.Second, // block duration
|
||||||
|
// false, // noAck flag
|
||||||
|
// "mystream", // stream name
|
||||||
|
// )
|
||||||
|
//
|
||||||
// Doesn't benefit from pooling redis connections of blocking queries.
|
// Doesn't benefit from pooling redis connections of blocking queries.
|
||||||
func (s *Redis) XReadGroup(node RedisNode, group string, consumerId string, count int64,
|
func (s *Redis) XReadGroup(node RedisNode, group string, consumerId string, count int64,
|
||||||
block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
|
block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
|
||||||
@@ -1847,6 +1926,10 @@ func (s *Redis) XReadGroup(node RedisNode, group string, consumerId string, coun
|
|||||||
}
|
}
|
||||||
|
|
||||||
// XReadGroupCtx is the context-aware version of XReadGroup.
|
// XReadGroupCtx is the context-aware version of XReadGroup.
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode to avoid
|
||||||
|
// exhausting the connection pool. See XReadGroup for usage examples.
|
||||||
|
//
|
||||||
// Doesn't benefit from pooling redis connections of blocking queries.
|
// Doesn't benefit from pooling redis connections of blocking queries.
|
||||||
func (s *Redis) XReadGroupCtx(ctx context.Context, node RedisNode, group string, consumerId string,
|
func (s *Redis) XReadGroupCtx(ctx context.Context, node RedisNode, group string, consumerId string,
|
||||||
count int64, block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
|
count int64, block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
|
||||||
|
|||||||
@@ -1104,6 +1104,45 @@ func TestRedis_GetDel(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRedis_GetEx(t *testing.T) {
|
||||||
|
t.Run("get_ex", func(t *testing.T) {
|
||||||
|
runOnRedis(t, func(client *Redis) {
|
||||||
|
val, err := client.GetEx("getex_key", 10)
|
||||||
|
assert.Equal(t, "", val)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = client.Set("getex_key", "getex_value")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
val, err = client.GetEx("getex_key", 10)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "getex_value", val)
|
||||||
|
val, err = client.Get("getex_key")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "getex_value", val)
|
||||||
|
|
||||||
|
ttl, err := client.Ttl("getex_key")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.True(t, ttl > 0 && ttl <= 10)
|
||||||
|
|
||||||
|
val, err = client.GetEx("getex_key", 5)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "getex_value", val)
|
||||||
|
|
||||||
|
ttl, err = client.Ttl("getex_key")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.True(t, ttl > 0 && ttl <= 5)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("get_ex_with_error", func(t *testing.T) {
|
||||||
|
runOnRedisWithError(t, func(client *Redis) {
|
||||||
|
_, err := newRedis(client.Addr, badType()).GetEx("hello", 10)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestRedis_GetSet(t *testing.T) {
|
func TestRedis_GetSet(t *testing.T) {
|
||||||
t.Run("set_get", func(t *testing.T) {
|
t.Run("set_get", func(t *testing.T) {
|
||||||
runOnRedis(t, func(client *Redis) {
|
runOnRedis(t, func(client *Redis) {
|
||||||
|
|||||||
@@ -13,7 +13,37 @@ type ClosableNode interface {
|
|||||||
Close()
|
Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateBlockingNode returns a ClosableNode.
|
// CreateBlockingNode creates a dedicated RedisNode for blocking operations.
|
||||||
|
//
|
||||||
|
// Blocking Redis commands (like BLPOP, BRPOP, XREADGROUP with block parameter) hold connections
|
||||||
|
// for extended periods while waiting for data. Using them with the regular Redis connection pool
|
||||||
|
// can exhaust all available connections, causing other operations to fail or timeout.
|
||||||
|
//
|
||||||
|
// CreateBlockingNode creates a separate Redis client with a minimal connection pool (size 1) that
|
||||||
|
// is dedicated to blocking operations. This ensures blocking commands don't interfere with regular
|
||||||
|
// Redis operations.
|
||||||
|
//
|
||||||
|
// Example usage:
|
||||||
|
//
|
||||||
|
// rds := redis.MustNewRedis(redis.RedisConf{
|
||||||
|
// Host: "localhost:6379",
|
||||||
|
// Type: redis.NodeType,
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// // Create a dedicated node for blocking operations
|
||||||
|
// node, err := redis.CreateBlockingNode(rds)
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
// defer node.Close() // Important: close the node when done
|
||||||
|
//
|
||||||
|
// // Use the node for blocking operations
|
||||||
|
// value, err := rds.Blpop(node, "mylist")
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// The returned ClosableNode must be closed when no longer needed to release resources.
|
||||||
func CreateBlockingNode(r *Redis) (ClosableNode, error) {
|
func CreateBlockingNode(r *Redis) (ClosableNode, error) {
|
||||||
timeout := readWriteTimeout + blockingQueryTimeout
|
timeout := readWriteTimeout + blockingQueryTimeout
|
||||||
|
|
||||||
|
|||||||
@@ -70,25 +70,16 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getValueInterface(value reflect.Value) (any, error) {
|
func getValueInterface(value reflect.Value) (any, error) {
|
||||||
switch value.Kind() {
|
if !value.CanAddr() || !value.Addr().CanInterface() {
|
||||||
case reflect.Ptr:
|
return nil, ErrNotReadableValue
|
||||||
if !value.CanInterface() {
|
|
||||||
return nil, ErrNotReadableValue
|
|
||||||
}
|
|
||||||
|
|
||||||
if value.IsNil() {
|
|
||||||
baseValueType := mapping.Deref(value.Type())
|
|
||||||
value.Set(reflect.New(baseValueType))
|
|
||||||
}
|
|
||||||
|
|
||||||
return value.Interface(), nil
|
|
||||||
default:
|
|
||||||
if !value.CanAddr() || !value.Addr().CanInterface() {
|
|
||||||
return nil, ErrNotReadableValue
|
|
||||||
}
|
|
||||||
|
|
||||||
return value.Addr().Interface(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if value.Kind() == reflect.Pointer && value.IsNil() {
|
||||||
|
baseValueType := mapping.Deref(value.Type())
|
||||||
|
value.Set(reflect.New(baseValueType))
|
||||||
|
}
|
||||||
|
|
||||||
|
return value.Addr().Interface(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isScanFailed(err error) bool {
|
func isScanFailed(err error) bool {
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -1575,6 +1577,782 @@ func TestAnonymousStructPrError(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsZeroValueStructPtr(t *testing.T) {
|
||||||
|
secondNamePtr := "second_ptr"
|
||||||
|
secondAgePtr := int64(30)
|
||||||
|
thirdNamePtr := "third_ptr"
|
||||||
|
thirdAgePtr := int64(0)
|
||||||
|
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
NamePtr *string
|
||||||
|
Age int64
|
||||||
|
AgePtr *int64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "first",
|
||||||
|
NamePtr: nil,
|
||||||
|
Age: 2,
|
||||||
|
AgePtr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "second",
|
||||||
|
NamePtr: &secondNamePtr,
|
||||||
|
Age: 3,
|
||||||
|
AgePtr: &secondAgePtr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "",
|
||||||
|
NamePtr: &thirdNamePtr,
|
||||||
|
Age: 0,
|
||||||
|
AgePtr: &thirdAgePtr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
Age int64 `db:"age"`
|
||||||
|
AgePtr *int64 `db:"age_ptr"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
NamePtr *string `db:"name_ptr"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "name_ptr", "age", "age_ptr"}).
|
||||||
|
AddRow("first", nil, 2, nil).
|
||||||
|
AddRow("second", "second_ptr", 3, 30).
|
||||||
|
AddRow("", "third_ptr", 0, 0)
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").
|
||||||
|
WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, name_ptr, age, age_ptr from users where user=?", "anyone"))
|
||||||
|
|
||||||
|
assert.Equal(t, 3, len(value), "应该返回3行数据")
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
|
||||||
|
assert.Equal(t, each.Name, value[i].Name)
|
||||||
|
assert.Equal(t, each.Age, value[i].Age)
|
||||||
|
|
||||||
|
if each.NamePtr == nil {
|
||||||
|
assert.Nil(t, value[i].NamePtr)
|
||||||
|
} else {
|
||||||
|
assert.NotNil(t, value[i].NamePtr)
|
||||||
|
assert.Equal(t, *each.NamePtr, *value[i].NamePtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if each.AgePtr == nil {
|
||||||
|
assert.Nil(t, value[i].AgePtr)
|
||||||
|
} else {
|
||||||
|
assert.NotNil(t, value[i].AgePtr)
|
||||||
|
assert.Equal(t, *each.AgePtr, *value[i].AgePtr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsAllNullStructPtrFields(t *testing.T) {
|
||||||
|
expect := []struct {
|
||||||
|
NamePtr *string
|
||||||
|
AgePtr *int64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
NamePtr: nil,
|
||||||
|
AgePtr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NamePtr: stringPtr("second"),
|
||||||
|
AgePtr: int64Ptr(30),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NamePtr: nil,
|
||||||
|
AgePtr: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
AgePtr *int64 `db:"age_ptr"`
|
||||||
|
NamePtr *string `db:"name_ptr"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"name_ptr", "age_ptr"}).
|
||||||
|
AddRow(nil, nil).
|
||||||
|
AddRow("second", 30).
|
||||||
|
AddRow(nil, nil)
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").
|
||||||
|
WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name_ptr, age_ptr from users where user=?", "anyone"))
|
||||||
|
|
||||||
|
assert.Equal(t, 3, len(value))
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
if each.NamePtr == nil {
|
||||||
|
assert.Nil(t, value[i].NamePtr)
|
||||||
|
} else {
|
||||||
|
assert.NotNil(t, value[i].NamePtr)
|
||||||
|
assert.Equal(t, *each.NamePtr, *value[i].NamePtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if each.AgePtr == nil {
|
||||||
|
assert.Nil(t, value[i].AgePtr)
|
||||||
|
} else {
|
||||||
|
assert.NotNil(t, value[i].AgePtr)
|
||||||
|
assert.Equal(t, *each.AgePtr, *value[i].AgePtr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsWithSqlNullTypes(t *testing.T) {
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
NullName sql.NullString
|
||||||
|
Age int64
|
||||||
|
NullAge sql.NullInt64
|
||||||
|
Score float64
|
||||||
|
NullScore sql.NullFloat64
|
||||||
|
Active bool
|
||||||
|
NullActive sql.NullBool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "first",
|
||||||
|
NullName: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Age: 20,
|
||||||
|
NullAge: sql.NullInt64{
|
||||||
|
Int64: 0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Score: 85.5,
|
||||||
|
NullScore: sql.NullFloat64{
|
||||||
|
Float64: 0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Active: true,
|
||||||
|
NullActive: sql.NullBool{
|
||||||
|
Bool: false,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "second",
|
||||||
|
NullName: sql.NullString{
|
||||||
|
String: "not_null_name",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
Age: 25,
|
||||||
|
NullAge: sql.NullInt64{
|
||||||
|
Int64: 30,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
Score: 90.0,
|
||||||
|
NullScore: sql.NullFloat64{
|
||||||
|
Float64: 95.5,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
Active: false,
|
||||||
|
NullActive: sql.NullBool{
|
||||||
|
Bool: true,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "third",
|
||||||
|
NullName: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Age: 0,
|
||||||
|
NullAge: sql.NullInt64{
|
||||||
|
Int64: 0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Score: 0,
|
||||||
|
NullScore: sql.NullFloat64{
|
||||||
|
Float64: 0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Active: false,
|
||||||
|
NullActive: sql.NullBool{
|
||||||
|
Bool: false,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
NullName sql.NullString `db:"null_name"`
|
||||||
|
Age int64 `db:"age"`
|
||||||
|
NullAge sql.NullInt64 `db:"null_age"`
|
||||||
|
Score float64 `db:"score"`
|
||||||
|
NullScore sql.NullFloat64 `db:"null_score"`
|
||||||
|
Active bool `db:"active"`
|
||||||
|
NullActive sql.NullBool `db:"null_active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{
|
||||||
|
"name", "null_name", "age", "null_age", "score", "null_score", "active", "null_active",
|
||||||
|
}).
|
||||||
|
AddRow("first", nil, 20, nil, 85.5, nil, true, nil).
|
||||||
|
AddRow("second", "not_null_name", 25, 30, 90.0, 95.5, false, true).
|
||||||
|
AddRow("third", nil, 0, nil, 0, nil, false, nil)
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from users where type=?").
|
||||||
|
WithArgs("test").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, null_name, age, null_age, score, null_score, active, null_active from users where type=?", "test"))
|
||||||
|
|
||||||
|
assert.Equal(t, 3, len(value))
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
assert.Equal(t, each.Name, value[i].Name)
|
||||||
|
assert.Equal(t, each.Age, value[i].Age)
|
||||||
|
assert.Equal(t, each.Score, value[i].Score)
|
||||||
|
assert.Equal(t, each.Active, value[i].Active)
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullName.Valid, value[i].NullName.Valid)
|
||||||
|
if each.NullName.Valid {
|
||||||
|
assert.Equal(t, each.NullName.String, value[i].NullName.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullAge.Valid, value[i].NullAge.Valid)
|
||||||
|
if each.NullAge.Valid {
|
||||||
|
assert.Equal(t, each.NullAge.Int64, value[i].NullAge.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullScore.Valid, value[i].NullScore.Valid)
|
||||||
|
if each.NullScore.Valid {
|
||||||
|
assert.Equal(t, each.NullScore.Float64, value[i].NullScore.Float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullActive.Valid, value[i].NullActive.Valid)
|
||||||
|
if each.NullActive.Valid {
|
||||||
|
assert.Equal(t, each.NullActive.Bool, value[i].NullActive.Bool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsSqlNullWithMixedData(t *testing.T) {
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
NullName sql.NullString
|
||||||
|
Age int64
|
||||||
|
NullAge sql.NullInt64
|
||||||
|
IsStudent bool
|
||||||
|
NullActive sql.NullBool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "student1",
|
||||||
|
NullName: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Age: 18,
|
||||||
|
NullAge: sql.NullInt64{
|
||||||
|
Int64: 0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
IsStudent: true,
|
||||||
|
NullActive: sql.NullBool{
|
||||||
|
Bool: false,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "student2",
|
||||||
|
NullName: sql.NullString{
|
||||||
|
String: "has_nickname",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
Age: 20,
|
||||||
|
NullAge: sql.NullInt64{
|
||||||
|
Int64: 22,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
IsStudent: false,
|
||||||
|
NullActive: sql.NullBool{
|
||||||
|
Bool: true,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
NullName sql.NullString `db:"null_name"`
|
||||||
|
Age int64 `db:"age"`
|
||||||
|
NullAge sql.NullInt64 `db:"null_age"`
|
||||||
|
IsStudent bool `db:"is_student"`
|
||||||
|
NullActive sql.NullBool `db:"null_active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "null_name", "age", "null_age", "is_student", "null_active"}).
|
||||||
|
AddRow("student1", nil, 18, nil, true, nil).
|
||||||
|
AddRow("student2", "has_nickname", 20, 22, false, true)
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from students where class=?").
|
||||||
|
WithArgs("A").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, null_name, age, null_age, is_student, null_active from students where class=?", "A"))
|
||||||
|
|
||||||
|
assert.Equal(t, 2, len(value))
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
assert.Equal(t, each.Name, value[i].Name)
|
||||||
|
assert.Equal(t, each.Age, value[i].Age)
|
||||||
|
assert.Equal(t, each.IsStudent, value[i].IsStudent)
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullName.Valid, value[i].NullName.Valid)
|
||||||
|
if each.NullName.Valid {
|
||||||
|
assert.Equal(t, each.NullName.String, value[i].NullName.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullAge.Valid, value[i].NullAge.Valid)
|
||||||
|
if each.NullAge.Valid {
|
||||||
|
assert.Equal(t, each.NullAge.Int64, value[i].NullAge.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullActive.Valid, value[i].NullActive.Valid)
|
||||||
|
if each.NullActive.Valid {
|
||||||
|
assert.Equal(t, each.NullActive.Bool, value[i].NullActive.Bool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsSqlNullTime(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
futureTime := now.AddDate(1, 0, 0)
|
||||||
|
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
BirthDate sql.NullTime
|
||||||
|
LastLogin sql.NullTime
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "user1",
|
||||||
|
BirthDate: sql.NullTime{
|
||||||
|
Time: time.Time{},
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
LastLogin: sql.NullTime{
|
||||||
|
Time: now,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "user2",
|
||||||
|
BirthDate: sql.NullTime{
|
||||||
|
Time: futureTime,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
LastLogin: sql.NullTime{
|
||||||
|
Time: time.Time{},
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
BirthDate sql.NullTime `db:"birth_date"`
|
||||||
|
LastLogin sql.NullTime `db:"last_login"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "birth_date", "last_login"}).
|
||||||
|
AddRow("user1", nil, now).
|
||||||
|
AddRow("user2", futureTime, nil)
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from users").
|
||||||
|
WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, birth_date, last_login from users"))
|
||||||
|
|
||||||
|
assert.Equal(t, 2, len(value))
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
assert.Equal(t, each.Name, value[i].Name)
|
||||||
|
|
||||||
|
assert.Equal(t, each.BirthDate.Valid, value[i].BirthDate.Valid)
|
||||||
|
if each.BirthDate.Valid {
|
||||||
|
assert.WithinDuration(t, each.BirthDate.Time, value[i].BirthDate.Time, time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.LastLogin.Valid, value[i].LastLogin.Valid)
|
||||||
|
if each.LastLogin.Valid {
|
||||||
|
assert.WithinDuration(t, each.LastLogin.Time, value[i].LastLogin.Time, time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsSqlNullWithEmptyValues(t *testing.T) {
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
NullString sql.NullString
|
||||||
|
NullInt sql.NullInt64
|
||||||
|
NullFloat sql.NullFloat64
|
||||||
|
NullBool sql.NullBool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "empty_values",
|
||||||
|
NullString: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullInt: sql.NullInt64{
|
||||||
|
Int64: 0,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullFloat: sql.NullFloat64{
|
||||||
|
Float64: 0.0,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullBool: sql.NullBool{
|
||||||
|
Bool: false,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "null_values",
|
||||||
|
NullString: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
NullInt: sql.NullInt64{
|
||||||
|
Int64: 0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
NullFloat: sql.NullFloat64{
|
||||||
|
Float64: 0.0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
NullBool: sql.NullBool{
|
||||||
|
Bool: false,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "mixed_values",
|
||||||
|
NullString: sql.NullString{
|
||||||
|
String: "actual_value",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullInt: sql.NullInt64{
|
||||||
|
Int64: 0,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullFloat: sql.NullFloat64{
|
||||||
|
Float64: 0.0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
NullBool: sql.NullBool{
|
||||||
|
Bool: true,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
NullString sql.NullString `db:"null_string"`
|
||||||
|
NullInt sql.NullInt64 `db:"null_int"`
|
||||||
|
NullFloat sql.NullFloat64 `db:"null_float"`
|
||||||
|
NullBool sql.NullBool `db:"null_bool"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "null_string", "null_int", "null_float", "null_bool"}).
|
||||||
|
AddRow("empty_values", "", 0, 0.0, false).
|
||||||
|
AddRow("null_values", nil, nil, nil, nil).
|
||||||
|
AddRow("mixed_values", "actual_value", 0, nil, true)
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from test_table").
|
||||||
|
WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, null_string, null_int, null_float, null_bool from test_table"))
|
||||||
|
|
||||||
|
assert.Equal(t, 3, len(value))
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
|
||||||
|
assert.Equal(t, each.Name, value[i].Name)
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullString.Valid, value[i].NullString.Valid)
|
||||||
|
if each.NullString.Valid {
|
||||||
|
assert.Equal(t, each.NullString.String, value[i].NullString.String)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, "", value[i].NullString.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullInt.Valid, value[i].NullInt.Valid)
|
||||||
|
if each.NullInt.Valid {
|
||||||
|
assert.Equal(t, each.NullInt.Int64, value[i].NullInt.Int64)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, int64(0), value[i].NullInt.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullFloat.Valid, value[i].NullFloat.Valid)
|
||||||
|
if each.NullFloat.Valid {
|
||||||
|
assert.Equal(t, each.NullFloat.Float64, value[i].NullFloat.Float64)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, 0.0, value[i].NullFloat.Float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullBool.Valid, value[i].NullBool.Valid)
|
||||||
|
if each.NullBool.Valid {
|
||||||
|
assert.Equal(t, each.NullBool.Bool, value[i].NullBool.Bool)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, false, value[i].NullBool.Bool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsSqlNullStringEmptyVsNull(t *testing.T) {
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
EmptyString sql.NullString
|
||||||
|
NullString sql.NullString
|
||||||
|
NormalString sql.NullString
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "row1",
|
||||||
|
EmptyString: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullString: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
NormalString: sql.NullString{
|
||||||
|
String: "hello",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "row2",
|
||||||
|
EmptyString: sql.NullString{
|
||||||
|
String: " ",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullString: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
NormalString: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
EmptyString sql.NullString `db:"empty_string"`
|
||||||
|
NullString sql.NullString `db:"null_string"`
|
||||||
|
NormalString sql.NullString `db:"normal_string"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "empty_string", "null_string", "normal_string"}).
|
||||||
|
AddRow("row1", "", nil, "hello").
|
||||||
|
AddRow("row2", " ", nil, "")
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from string_test").
|
||||||
|
WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, empty_string, null_string, normal_string from string_test"))
|
||||||
|
|
||||||
|
assert.Equal(t, 2, len(value))
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
assert.True(t, value[i].EmptyString.Valid)
|
||||||
|
assert.Equal(t, each.EmptyString.String, value[i].EmptyString.String)
|
||||||
|
|
||||||
|
assert.False(t, value[i].NullString.Valid)
|
||||||
|
assert.Equal(t, "", value[i].NullString.String)
|
||||||
|
|
||||||
|
assert.Equal(t, each.NormalString.Valid, value[i].NormalString.Valid)
|
||||||
|
if each.NormalString.Valid {
|
||||||
|
assert.Equal(t, each.NormalString.String, value[i].NormalString.String)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetValueInterface(t *testing.T) {
|
||||||
|
t.Run("non_pointer_field", func(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
}
|
||||||
|
s := testStruct{}
|
||||||
|
v := reflect.ValueOf(&s).Elem()
|
||||||
|
|
||||||
|
nameField := v.Field(0)
|
||||||
|
result, err := getValueInterface(nameField)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
|
||||||
|
// Should return pointer to the field
|
||||||
|
ptr, ok := result.(*string)
|
||||||
|
assert.True(t, ok)
|
||||||
|
*ptr = "test"
|
||||||
|
assert.Equal(t, "test", s.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pointer_field_nil", func(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
NamePtr *string
|
||||||
|
AgePtr *int64
|
||||||
|
}
|
||||||
|
s := testStruct{}
|
||||||
|
v := reflect.ValueOf(&s).Elem()
|
||||||
|
|
||||||
|
// Test with nil pointer field
|
||||||
|
namePtrField := v.Field(0)
|
||||||
|
assert.True(t, namePtrField.IsNil(), "initial pointer should be nil")
|
||||||
|
|
||||||
|
result, err := getValueInterface(namePtrField)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
|
||||||
|
// Should have allocated the pointer
|
||||||
|
assert.False(t, namePtrField.IsNil(), "pointer should be allocated after getValueInterface")
|
||||||
|
|
||||||
|
// Should return pointer to pointer field
|
||||||
|
ptrPtr, ok := result.(**string)
|
||||||
|
assert.True(t, ok)
|
||||||
|
testValue := "initialized"
|
||||||
|
*ptrPtr = &testValue
|
||||||
|
assert.NotNil(t, s.NamePtr)
|
||||||
|
assert.Equal(t, "initialized", *s.NamePtr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pointer_field_already_allocated", func(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
NamePtr *string
|
||||||
|
}
|
||||||
|
initial := "existing"
|
||||||
|
s := testStruct{NamePtr: &initial}
|
||||||
|
v := reflect.ValueOf(&s).Elem()
|
||||||
|
|
||||||
|
namePtrField := v.Field(0)
|
||||||
|
assert.False(t, namePtrField.IsNil(), "pointer should not be nil initially")
|
||||||
|
|
||||||
|
result, err := getValueInterface(namePtrField)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
|
||||||
|
// Should return pointer to pointer field
|
||||||
|
ptrPtr, ok := result.(**string)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
// Verify it points to the existing value
|
||||||
|
assert.Equal(t, "existing", **ptrPtr)
|
||||||
|
|
||||||
|
// Modify through the returned pointer
|
||||||
|
newValue := "modified"
|
||||||
|
*ptrPtr = &newValue
|
||||||
|
assert.Equal(t, "modified", *s.NamePtr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pointer_field_zero_value", func(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
IntPtr *int
|
||||||
|
}
|
||||||
|
s := testStruct{}
|
||||||
|
v := reflect.ValueOf(&s).Elem()
|
||||||
|
|
||||||
|
intPtrField := v.Field(0)
|
||||||
|
result, err := getValueInterface(intPtrField)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// After calling getValueInterface, nil pointer should be allocated
|
||||||
|
assert.NotNil(t, s.IntPtr)
|
||||||
|
|
||||||
|
// Set zero value through returned interface
|
||||||
|
ptrPtr, ok := result.(**int)
|
||||||
|
assert.True(t, ok)
|
||||||
|
zero := 0
|
||||||
|
*ptrPtr = &zero
|
||||||
|
assert.Equal(t, 0, *s.IntPtr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not_addressable_value", func(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
s := testStruct{Name: "test"}
|
||||||
|
v := reflect.ValueOf(s) // Non-pointer, not addressable
|
||||||
|
|
||||||
|
nameField := v.Field(0)
|
||||||
|
result, err := getValueInterface(nameField)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, ErrNotReadableValue, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple_pointer_types", func(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
StringPtr *string
|
||||||
|
IntPtr *int
|
||||||
|
Int64Ptr *int64
|
||||||
|
FloatPtr *float64
|
||||||
|
BoolPtr *bool
|
||||||
|
}
|
||||||
|
s := testStruct{}
|
||||||
|
v := reflect.ValueOf(&s).Elem()
|
||||||
|
|
||||||
|
// Test each pointer type gets properly initialized
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
field := v.Field(i)
|
||||||
|
assert.True(t, field.IsNil(), "field %d should start as nil", i)
|
||||||
|
|
||||||
|
result, err := getValueInterface(field)
|
||||||
|
assert.NoError(t, err, "field %d should not error", i)
|
||||||
|
assert.NotNil(t, result, "field %d result should not be nil", i)
|
||||||
|
|
||||||
|
// After getValueInterface, pointer should be allocated
|
||||||
|
assert.False(t, field.IsNil(), "field %d should be allocated", i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringPtr(s string) *string {
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
func int64Ptr(i int64) *int64 {
|
||||||
|
return &i
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkIgnore(b *testing.B) {
|
func BenchmarkIgnore(b *testing.B) {
|
||||||
db, mock, err := sqlmock.New()
|
db, mock, err := sqlmock.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
package threading
|
package threading
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRoutineGroupRun(t *testing.T) {
|
func TestRoutineGroupRun(t *testing.T) {
|
||||||
@@ -25,7 +24,7 @@ func TestRoutineGroupRun(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRoutingGroupRunSafe(t *testing.T) {
|
func TestRoutingGroupRunSafe(t *testing.T) {
|
||||||
log.SetOutput(io.Discard)
|
logtest.Discard(t)
|
||||||
|
|
||||||
var count int32
|
var count int32
|
||||||
group := NewRoutineGroup()
|
group := NewRoutineGroup()
|
||||||
|
|||||||
@@ -3,13 +3,12 @@ package threading
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/lang"
|
"github.com/zeromicro/go-zero/core/lang"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRoutineId(t *testing.T) {
|
func TestRoutineId(t *testing.T) {
|
||||||
@@ -17,7 +16,7 @@ func TestRoutineId(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRunSafe(t *testing.T) {
|
func TestRunSafe(t *testing.T) {
|
||||||
log.SetOutput(io.Discard)
|
logtest.Discard(t)
|
||||||
|
|
||||||
i := 0
|
i := 0
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/lang"
|
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
"go.opentelemetry.io/otel/exporters/jaeger"
|
"go.opentelemetry.io/otel/exporters/jaeger"
|
||||||
@@ -30,42 +29,36 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
agents = make(map[string]lang.PlaceholderType)
|
once sync.Once
|
||||||
lock sync.Mutex
|
tp *sdktrace.TracerProvider
|
||||||
tp *sdktrace.TracerProvider
|
shutdownOnceFn = sync.OnceFunc(func() {
|
||||||
|
if tp != nil {
|
||||||
|
_ = tp.Shutdown(context.Background())
|
||||||
|
}
|
||||||
|
})
|
||||||
)
|
)
|
||||||
|
|
||||||
// StartAgent starts an opentelemetry agent.
|
// StartAgent starts an opentelemetry agent.
|
||||||
|
// It uses sync.Once to ensure the agent is initialized only once,
|
||||||
|
// similar to prometheus.StartAgent and logx.SetUp.
|
||||||
|
// This prevents multiple ServiceConf.SetUp() calls from reinitializing
|
||||||
|
// the global tracer provider when running multiple servers (e.g., REST + RPC)
|
||||||
|
// in the same process.
|
||||||
func StartAgent(c Config) {
|
func StartAgent(c Config) {
|
||||||
if c.Disabled {
|
if c.Disabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lock.Lock()
|
once.Do(func() {
|
||||||
defer lock.Unlock()
|
if err := startAgent(c); err != nil {
|
||||||
|
logx.Error(err)
|
||||||
_, ok := agents[c.Endpoint]
|
}
|
||||||
if ok {
|
})
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// if error happens, let later calls run.
|
|
||||||
if err := startAgent(c); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
agents[c.Endpoint] = lang.Placeholder
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// StopAgent shuts down the span processors in the order they were registered.
|
// StopAgent shuts down the span processors in the order they were registered.
|
||||||
func StopAgent() {
|
func StopAgent() {
|
||||||
lock.Lock()
|
shutdownOnceFn()
|
||||||
defer lock.Unlock()
|
|
||||||
|
|
||||||
if tp != nil {
|
|
||||||
_ = tp.Shutdown(context.Background())
|
|
||||||
tp = nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func createExporter(c Config) (sdktrace.SpanExporter, error) {
|
func createExporter(c Config) (sdktrace.SpanExporter, error) {
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
package trace
|
package trace
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStartAgent(t *testing.T) {
|
func TestStartAgent(t *testing.T) {
|
||||||
@@ -89,23 +92,305 @@ func TestStartAgent(t *testing.T) {
|
|||||||
StartAgent(c10)
|
StartAgent(c10)
|
||||||
defer StopAgent()
|
defer StopAgent()
|
||||||
|
|
||||||
lock.Lock()
|
// With sync.Once, only the first non-disabled config (c1) takes effect.
|
||||||
defer lock.Unlock()
|
// Subsequent calls are ignored, which is the desired behavior to prevent
|
||||||
|
// multiple servers (REST + RPC) from reinitializing the global tracer.
|
||||||
// because remotehost cannot be resolved
|
assert.NotNil(t, tp)
|
||||||
assert.Equal(t, 6, len(agents))
|
}
|
||||||
_, ok := agents[""]
|
|
||||||
assert.True(t, ok)
|
func TestCreateExporter_InvalidFilePath(t *testing.T) {
|
||||||
_, ok = agents[endpoint1]
|
logx.Disable()
|
||||||
assert.True(t, ok)
|
|
||||||
_, ok = agents[endpoint2]
|
c := Config{
|
||||||
assert.False(t, ok)
|
Name: "test-invalid-file",
|
||||||
_, ok = agents[endpoint5]
|
Endpoint: "/non-existent-directory/trace.log",
|
||||||
assert.True(t, ok)
|
Batcher: kindFile,
|
||||||
_, ok = agents[endpoint6]
|
}
|
||||||
assert.False(t, ok)
|
|
||||||
_, ok = agents[endpoint71]
|
_, err := createExporter(c)
|
||||||
assert.True(t, ok)
|
assert.Error(t, err)
|
||||||
_, ok = agents[endpoint72]
|
assert.Contains(t, err.Error(), "file exporter endpoint error")
|
||||||
assert.False(t, ok)
|
}
|
||||||
|
|
||||||
|
func TestCreateExporter_UnknownBatcher(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
c := Config{
|
||||||
|
Name: "test-unknown",
|
||||||
|
Endpoint: "localhost:1234",
|
||||||
|
Batcher: "unknown-batcher-type",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := createExporter(c)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown exporter")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateExporter_ValidExporters(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
wantErr bool
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid file exporter",
|
||||||
|
config: Config{
|
||||||
|
Name: "file-test",
|
||||||
|
Endpoint: "/tmp/trace-test.log",
|
||||||
|
Batcher: kindFile,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid file path",
|
||||||
|
config: Config{
|
||||||
|
Name: "file-test-invalid",
|
||||||
|
Endpoint: "/invalid-path/that/does/not/exist/trace.log",
|
||||||
|
Batcher: kindFile,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "file exporter endpoint error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown batcher",
|
||||||
|
config: Config{
|
||||||
|
Name: "unknown-test",
|
||||||
|
Endpoint: "localhost:1234",
|
||||||
|
Batcher: "invalid-batcher",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "unknown exporter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "jaeger http",
|
||||||
|
config: Config{
|
||||||
|
Name: "jaeger-http",
|
||||||
|
Endpoint: "http://localhost:14268/api/traces",
|
||||||
|
Batcher: kindJaeger,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "jaeger udp",
|
||||||
|
config: Config{
|
||||||
|
Name: "jaeger-udp",
|
||||||
|
Endpoint: "udp://localhost:6831",
|
||||||
|
Batcher: kindJaeger,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zipkin",
|
||||||
|
config: Config{
|
||||||
|
Name: "zipkin",
|
||||||
|
Endpoint: "http://localhost:9411/api/v2/spans",
|
||||||
|
Batcher: kindZipkin,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "otlpgrpc",
|
||||||
|
config: Config{
|
||||||
|
Name: "otlpgrpc",
|
||||||
|
Endpoint: "localhost:4317",
|
||||||
|
Batcher: kindOtlpGrpc,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "otlpgrpc with headers",
|
||||||
|
config: Config{
|
||||||
|
Name: "otlpgrpc-headers",
|
||||||
|
Endpoint: "localhost:4317",
|
||||||
|
Batcher: kindOtlpGrpc,
|
||||||
|
OtlpHeaders: map[string]string{
|
||||||
|
"authorization": "Bearer token123",
|
||||||
|
"x-custom-key": "custom-value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "otlphttp",
|
||||||
|
config: Config{
|
||||||
|
Name: "otlphttp",
|
||||||
|
Endpoint: "localhost:4318",
|
||||||
|
Batcher: kindOtlpHttp,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "otlphttp with headers",
|
||||||
|
config: Config{
|
||||||
|
Name: "otlphttp-headers",
|
||||||
|
Endpoint: "localhost:4318",
|
||||||
|
Batcher: kindOtlpHttp,
|
||||||
|
OtlpHeaders: map[string]string{
|
||||||
|
"authorization": "Bearer token456",
|
||||||
|
"x-api-key": "api-key-value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "otlphttp with headers and path",
|
||||||
|
config: Config{
|
||||||
|
Name: "otlphttp-headers-path",
|
||||||
|
Endpoint: "localhost:4318",
|
||||||
|
Batcher: kindOtlpHttp,
|
||||||
|
OtlpHttpPath: "/v1/traces",
|
||||||
|
OtlpHeaders: map[string]string{
|
||||||
|
"authorization": "Bearer token789",
|
||||||
|
"x-custom-trace": "trace-id",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "otlphttp with secure connection",
|
||||||
|
config: Config{
|
||||||
|
Name: "otlphttp-secure",
|
||||||
|
Endpoint: "localhost:4318",
|
||||||
|
Batcher: kindOtlpHttp,
|
||||||
|
OtlpHttpSecure: true,
|
||||||
|
OtlpHeaders: map[string]string{
|
||||||
|
"authorization": "Bearer secure-token",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
exporter, err := createExporter(tt.config)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
if tt.errMsg != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
}
|
||||||
|
assert.Nil(t, exporter)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, exporter)
|
||||||
|
// Clean up the exporter
|
||||||
|
if exporter != nil {
|
||||||
|
_ = exporter.Shutdown(context.Background())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStopAgent(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
// StopAgent should be idempotent and safe to call multiple times
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
StopAgent()
|
||||||
|
StopAgent()
|
||||||
|
StopAgent()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartAgent_WithEndpoint(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty endpoint - no exporter created",
|
||||||
|
config: Config{
|
||||||
|
Name: "test-no-endpoint",
|
||||||
|
Sampler: 1.0,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid endpoint with file exporter",
|
||||||
|
config: Config{
|
||||||
|
Name: "test-with-endpoint",
|
||||||
|
Endpoint: "/tmp/test-trace.log",
|
||||||
|
Batcher: kindFile,
|
||||||
|
Sampler: 1.0,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "endpoint with invalid exporter type",
|
||||||
|
config: Config{
|
||||||
|
Name: "test-invalid-batcher",
|
||||||
|
Endpoint: "localhost:1234",
|
||||||
|
Batcher: "invalid-type",
|
||||||
|
Sampler: 1.0,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "endpoint with invalid file path",
|
||||||
|
config: Config{
|
||||||
|
Name: "test-invalid-path",
|
||||||
|
Endpoint: "/non/existent/path/trace.log",
|
||||||
|
Batcher: kindFile,
|
||||||
|
Sampler: 1.0,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset tp for each test
|
||||||
|
originalTp := tp
|
||||||
|
tp = nil
|
||||||
|
defer func() {
|
||||||
|
if tp != nil {
|
||||||
|
_ = tp.Shutdown(context.Background())
|
||||||
|
}
|
||||||
|
tp = originalTp
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := startAgent(tt.config)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, tp, "TracerProvider should be created")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartAgent_ErrorHandler(t *testing.T) {
|
||||||
|
// Setup a tracer provider to test error handler
|
||||||
|
originalTp := tp
|
||||||
|
tp = nil
|
||||||
|
defer func() {
|
||||||
|
if tp != nil {
|
||||||
|
_ = tp.Shutdown(context.Background())
|
||||||
|
}
|
||||||
|
tp = originalTp
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Call startAgent to set up the error handler
|
||||||
|
config := Config{
|
||||||
|
Name: "test-error-handler",
|
||||||
|
Sampler: 1.0,
|
||||||
|
}
|
||||||
|
err := startAgent(config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, tp)
|
||||||
|
|
||||||
|
// Verify the error handler was set and can be called without panicking
|
||||||
|
// We test this by calling otel.Handle which will invoke the registered error handler
|
||||||
|
testErr := errors.New("test otel error")
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
otel.Handle(testErr)
|
||||||
|
}, "Error handler should handle errors without panicking")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,16 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MetadataHeaderPrefix is the http prefix that represents custom metadata
|
||||||
|
// parameters to or from a gRPC call.
|
||||||
|
MetadataHeaderPrefix = "Grpc-Metadata-"
|
||||||
|
|
||||||
|
// MetadataTrailerPrefix is prepended to gRPC metadata as it is converted to
|
||||||
|
// HTTP headers in a response handled by go-zero gateway
|
||||||
|
MetadataTrailerPrefix = "Grpc-Trailer-"
|
||||||
|
)
|
||||||
|
|
||||||
type EventHandler struct {
|
type EventHandler struct {
|
||||||
Status *status.Status
|
Status *status.Status
|
||||||
writer io.Writer
|
writer io.Writer
|
||||||
@@ -31,9 +41,10 @@ func NewEventHandler(writer io.Writer, resolver jsonpb.AnyResolver) *EventHandle
|
|||||||
func (h *EventHandler) OnReceiveHeaders(md metadata.MD) {
|
func (h *EventHandler) OnReceiveHeaders(md metadata.MD) {
|
||||||
w, ok := h.writer.(http.ResponseWriter)
|
w, ok := h.writer.(http.ResponseWriter)
|
||||||
if ok {
|
if ok {
|
||||||
for k, v := range md {
|
for k, vs := range md {
|
||||||
for _, val := range v {
|
header := defaultOutgoingHeaderMatcher(k)
|
||||||
w.Header().Add(k, val)
|
for _, v := range vs {
|
||||||
|
w.Header().Add(header, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -48,9 +59,10 @@ func (h *EventHandler) OnReceiveResponse(message proto.Message) {
|
|||||||
func (h *EventHandler) OnReceiveTrailers(status *status.Status, md metadata.MD) {
|
func (h *EventHandler) OnReceiveTrailers(status *status.Status, md metadata.MD) {
|
||||||
w, ok := h.writer.(http.ResponseWriter)
|
w, ok := h.writer.(http.ResponseWriter)
|
||||||
if ok {
|
if ok {
|
||||||
for k, v := range md {
|
for k, vs := range md {
|
||||||
for _, val := range v {
|
header := defaultOutgoingTrailerMatcher(k)
|
||||||
w.Header().Add(k, val)
|
for _, v := range vs {
|
||||||
|
w.Header().Add(header, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -63,3 +75,11 @@ func (h *EventHandler) OnResolveMethod(_ *desc.MethodDescriptor) {
|
|||||||
|
|
||||||
func (h *EventHandler) OnSendHeaders(_ metadata.MD) {
|
func (h *EventHandler) OnSendHeaders(_ metadata.MD) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func defaultOutgoingHeaderMatcher(key string) string {
|
||||||
|
return MetadataHeaderPrefix + key
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultOutgoingTrailerMatcher(key string) string {
|
||||||
|
return MetadataTrailerPrefix + key
|
||||||
|
}
|
||||||
|
|||||||
@@ -40,8 +40,8 @@ func TestEventHandler_OnReceiveTrailers(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectedStatus: codes.OK,
|
expectedStatus: codes.OK,
|
||||||
expectedHeader: map[string][]string{
|
expectedHeader: map[string][]string{
|
||||||
"X-Custom-Header": {"value1", "value2"},
|
"Grpc-Trailer-X-Custom-Header": {"value1", "value2"},
|
||||||
"X-Another-Header": {"single-value"},
|
"Grpc-Trailer-X-Another-Header": {"single-value"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -100,9 +100,9 @@ func TestEventHandler_OnReceiveHeaders(t *testing.T) {
|
|||||||
"x-another-header": []string{"single-value"},
|
"x-another-header": []string{"single-value"},
|
||||||
},
|
},
|
||||||
expectedHeader: map[string][]string{
|
expectedHeader: map[string][]string{
|
||||||
"Content-Type": {"application/json"},
|
"Grpc-Metadata-Content-Type": {"application/json"},
|
||||||
"X-Custom-Header": {"value1", "value2"},
|
"Grpc-Metadata-X-Custom-Header": {"value1", "value2"},
|
||||||
"X-Another-Header": {"single-value"},
|
"Grpc-Metadata-X-Another-Header": {"single-value"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -158,7 +158,81 @@ func TestEventHandler_OnReceiveHeaders_MultipleValues(t *testing.T) {
|
|||||||
"x-header-2": []string{"value3"},
|
"x-header-2": []string{"value3"},
|
||||||
})
|
})
|
||||||
|
|
||||||
// Check that headers are accumulated (not overwritten)
|
// Check that headers are accumulated (not overwritten) with proper prefix
|
||||||
assert.Equal(t, []string{"value1", "value2"}, recorder.Header()["X-Header-1"])
|
assert.Equal(t, []string{"value1", "value2"}, recorder.Header()["Grpc-Metadata-X-Header-1"])
|
||||||
assert.Equal(t, []string{"value3"}, recorder.Header()["X-Header-2"])
|
assert.Equal(t, []string{"value3"}, recorder.Header()["Grpc-Metadata-X-Header-2"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventHandler_OnReceiveHeaders_MetadataPrefix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
metadata metadata.MD
|
||||||
|
expectedHeader map[string][]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "all metadata headers should be prefixed with Grpc-Metadata-",
|
||||||
|
metadata: metadata.MD{
|
||||||
|
"content-type": []string{"application/grpc"},
|
||||||
|
"x-custom-header": []string{"value1"},
|
||||||
|
"authorization": []string{"Bearer token"},
|
||||||
|
},
|
||||||
|
expectedHeader: map[string][]string{
|
||||||
|
"Grpc-Metadata-Content-Type": {"application/grpc"},
|
||||||
|
"Grpc-Metadata-X-Custom-Header": {"value1"},
|
||||||
|
"Grpc-Metadata-Authorization": {"Bearer token"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case headers should be prefixed",
|
||||||
|
metadata: metadata.MD{
|
||||||
|
"Content-Type": []string{"APPLICATION/JSON"},
|
||||||
|
"X-Custom-Header": []string{"value1"},
|
||||||
|
},
|
||||||
|
expectedHeader: map[string][]string{
|
||||||
|
"Grpc-Metadata-Content-Type": {"APPLICATION/JSON"},
|
||||||
|
"Grpc-Metadata-X-Custom-Header": {"value1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple values for same header",
|
||||||
|
metadata: metadata.MD{
|
||||||
|
"x-multi-header": []string{"value1", "value2", "value3"},
|
||||||
|
},
|
||||||
|
expectedHeader: map[string][]string{
|
||||||
|
"Grpc-Metadata-X-Multi-Header": {"value1", "value2", "value3"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty metadata",
|
||||||
|
metadata: metadata.MD{},
|
||||||
|
expectedHeader: map[string][]string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
h := NewEventHandler(recorder, nil)
|
||||||
|
|
||||||
|
h.OnReceiveHeaders(tt.metadata)
|
||||||
|
|
||||||
|
// Check that headers are set correctly
|
||||||
|
for key, expectedValues := range tt.expectedHeader {
|
||||||
|
actualValues := recorder.Header()[key]
|
||||||
|
assert.Equal(t, expectedValues, actualValues, "Header %s should match", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure no unexpected headers are set
|
||||||
|
for actualKey := range recorder.Header() {
|
||||||
|
found := false
|
||||||
|
for expectedKey := range tt.expectedHeader {
|
||||||
|
if actualKey == expectedKey {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, found, "Unexpected header found: %s", actualKey)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,16 +11,40 @@ const (
|
|||||||
metadataPrefix = "gateway-"
|
metadataPrefix = "gateway-"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OpenTelemetry trace propagation headers that need to be forwarded to gRPC metadata.
|
||||||
|
// These headers are used by the W3C Trace Context standard for distributed tracing.
|
||||||
|
var traceHeaders = map[string]bool{
|
||||||
|
"traceparent": true,
|
||||||
|
"tracestate": true,
|
||||||
|
"baggage": true,
|
||||||
|
}
|
||||||
|
|
||||||
// ProcessHeaders builds the headers for the gateway from HTTP headers.
|
// ProcessHeaders builds the headers for the gateway from HTTP headers.
|
||||||
|
// It forwards both custom metadata headers (with Grpc-Metadata- prefix)
|
||||||
|
// and OpenTelemetry trace propagation headers (traceparent, tracestate, baggage)
|
||||||
|
// to ensure distributed tracing works correctly across the gateway.
|
||||||
func ProcessHeaders(header http.Header) []string {
|
func ProcessHeaders(header http.Header) []string {
|
||||||
var headers []string
|
var headers []string
|
||||||
|
|
||||||
for k, v := range header {
|
for k, v := range header {
|
||||||
|
// Forward OpenTelemetry trace propagation headers
|
||||||
|
// These must be lowercase per gRPC metadata conventions
|
||||||
|
if lowerKey := strings.ToLower(k); traceHeaders[lowerKey] {
|
||||||
|
for _, vv := range v {
|
||||||
|
headers = append(headers, lowerKey+":"+vv)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward custom metadata headers with Grpc-Metadata- prefix
|
||||||
if !strings.HasPrefix(k, metadataHeaderPrefix) {
|
if !strings.HasPrefix(k, metadataHeaderPrefix) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s%s", metadataPrefix, strings.TrimPrefix(k, metadataHeaderPrefix))
|
// gRPC metadata keys are case-insensitive and stored as lowercase,
|
||||||
|
// so we lowercase the key to match gRPC conventions
|
||||||
|
trimmedKey := strings.TrimPrefix(k, metadataHeaderPrefix)
|
||||||
|
key := strings.ToLower(fmt.Sprintf("%s%s", metadataPrefix, trimmedKey))
|
||||||
for _, vv := range v {
|
for _, vv := range v {
|
||||||
headers = append(headers, key+":"+vv)
|
headers = append(headers, key+":"+vv)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,5 +18,93 @@ func TestBuildHeadersWithValues(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/", http.NoBody)
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||||
req.Header.Add("grpc-metadata-a", "b")
|
req.Header.Add("grpc-metadata-a", "b")
|
||||||
req.Header.Add("grpc-metadata-b", "b")
|
req.Header.Add("grpc-metadata-b", "b")
|
||||||
assert.ElementsMatch(t, []string{"gateway-A:b", "gateway-B:b"}, ProcessHeaders(req.Header))
|
assert.ElementsMatch(t, []string{"gateway-a:b", "gateway-b:b"}, ProcessHeaders(req.Header))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessHeadersWithTraceContext(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||||
|
req.Header.Set("traceparent", "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
|
||||||
|
req.Header.Set("tracestate", "key1=value1,key2=value2")
|
||||||
|
req.Header.Set("baggage", "userId=alice,serverNode=DF:28")
|
||||||
|
|
||||||
|
headers := ProcessHeaders(req.Header)
|
||||||
|
|
||||||
|
assert.Len(t, headers, 3)
|
||||||
|
assert.Contains(t, headers, "traceparent:00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
|
||||||
|
assert.Contains(t, headers, "tracestate:key1=value1,key2=value2")
|
||||||
|
assert.Contains(t, headers, "baggage:userId=alice,serverNode=DF:28")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessHeadersWithMixedHeaders(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||||
|
req.Header.Set("traceparent", "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
|
||||||
|
req.Header.Set("grpc-metadata-custom", "value1")
|
||||||
|
req.Header.Set("content-type", "application/json")
|
||||||
|
req.Header.Set("tracestate", "key1=value1")
|
||||||
|
|
||||||
|
headers := ProcessHeaders(req.Header)
|
||||||
|
|
||||||
|
// Should include trace headers and grpc-metadata headers, but not regular headers
|
||||||
|
assert.Len(t, headers, 3)
|
||||||
|
assert.Contains(t, headers, "traceparent:00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
|
||||||
|
assert.Contains(t, headers, "tracestate:key1=value1")
|
||||||
|
assert.Contains(t, headers, "gateway-custom:value1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessHeadersTraceparentCaseInsensitive(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
headerKey string
|
||||||
|
headerVal string
|
||||||
|
expectedKey string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "lowercase traceparent",
|
||||||
|
headerKey: "traceparent",
|
||||||
|
headerVal: "00-trace-span-01",
|
||||||
|
expectedKey: "traceparent",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase Traceparent",
|
||||||
|
headerKey: "Traceparent",
|
||||||
|
headerVal: "00-trace-span-01",
|
||||||
|
expectedKey: "traceparent",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case TraceParent",
|
||||||
|
headerKey: "TraceParent",
|
||||||
|
headerVal: "00-trace-span-01",
|
||||||
|
expectedKey: "traceparent",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "lowercase tracestate",
|
||||||
|
headerKey: "tracestate",
|
||||||
|
headerVal: "key=value",
|
||||||
|
expectedKey: "tracestate",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case TraceState",
|
||||||
|
headerKey: "TraceState",
|
||||||
|
headerVal: "key=value",
|
||||||
|
expectedKey: "tracestate",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||||
|
req.Header.Set(tt.headerKey, tt.headerVal)
|
||||||
|
|
||||||
|
headers := ProcessHeaders(req.Header)
|
||||||
|
|
||||||
|
assert.Len(t, headers, 1)
|
||||||
|
assert.Contains(t, headers, tt.expectedKey+":"+tt.headerVal)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessHeadersEmptyHeaders(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||||
|
headers := ProcessHeaders(req.Header)
|
||||||
|
assert.Empty(t, headers)
|
||||||
}
|
}
|
||||||
|
|||||||
10
go.mod
10
go.mod
@@ -11,17 +11,17 @@ require (
|
|||||||
github.com/golang-jwt/jwt/v4 v4.5.2
|
github.com/golang-jwt/jwt/v4 v4.5.2
|
||||||
github.com/golang/protobuf v1.5.4
|
github.com/golang/protobuf v1.5.4
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/grafana/pyroscope-go v1.2.4
|
github.com/grafana/pyroscope-go v1.2.7
|
||||||
github.com/jackc/pgx/v5 v5.7.4
|
github.com/jackc/pgx/v5 v5.7.4
|
||||||
github.com/jhump/protoreflect v1.17.0
|
github.com/jhump/protoreflect v1.17.0
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2
|
github.com/pelletier/go-toml/v2 v2.2.2
|
||||||
github.com/prometheus/client_golang v1.21.1
|
github.com/prometheus/client_golang v1.21.1
|
||||||
github.com/redis/go-redis/v9 v9.12.1
|
github.com/redis/go-redis/v9 v9.17.2
|
||||||
github.com/spaolacci/murmur3 v1.1.0
|
github.com/spaolacci/murmur3 v1.1.0
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.11.1
|
||||||
go.etcd.io/etcd/api/v3 v3.5.15
|
go.etcd.io/etcd/api/v3 v3.5.15
|
||||||
go.etcd.io/etcd/client/v3 v3.5.15
|
go.etcd.io/etcd/client/v3 v3.5.15
|
||||||
go.mongodb.org/mongo-driver/v2 v2.3.0
|
go.mongodb.org/mongo-driver/v2 v2.4.1
|
||||||
go.opentelemetry.io/otel v1.24.0
|
go.opentelemetry.io/otel v1.24.0
|
||||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
|
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
|
||||||
@@ -72,7 +72,7 @@ require (
|
|||||||
github.com/google/gnostic-models v0.6.8 // indirect
|
github.com/google/gnostic-models v0.6.8 // indirect
|
||||||
github.com/google/go-cmp v0.6.0 // indirect
|
github.com/google/go-cmp v0.6.0 // indirect
|
||||||
github.com/google/gofuzz v1.2.0 // indirect
|
github.com/google/gofuzz v1.2.0 // indirect
|
||||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 // indirect
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
||||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
|||||||
20
go.sum
20
go.sum
@@ -78,10 +78,10 @@ github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJY
|
|||||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/grafana/pyroscope-go v1.2.4 h1:B22GMXz+O0nWLatxLuaP7o7L9dvP0clLvIpmeEQQM0Q=
|
github.com/grafana/pyroscope-go v1.2.7 h1:VWBBlqxjyR0Cwk2W6UrE8CdcdD80GOFNutj0Kb1T8ac=
|
||||||
github.com/grafana/pyroscope-go v1.2.4/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
|
github.com/grafana/pyroscope-go v1.2.7/go.mod h1:o/bpSLiJYYP6HQtvcoVKiE9s5RiNgjYTj1DhiddP2Pc=
|
||||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 h1:c1Us8i6eSmkW+Ez05d3co8kasnuOY813tbMN8i/a3Og=
|
||||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.9/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
|
||||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
||||||
@@ -154,8 +154,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
|
|||||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||||
github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg=
|
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||||
github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
@@ -176,8 +176,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
|||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
|
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
|
||||||
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
||||||
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
|
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
|
||||||
@@ -197,8 +197,8 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5
|
|||||||
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
|
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
|
||||||
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
|
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
|
||||||
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
|
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
|
||||||
go.mongodb.org/mongo-driver/v2 v2.3.0 h1:sh55yOXA2vUjW1QYw/2tRlHSQViwDyPnW61AwpZ4rtU=
|
go.mongodb.org/mongo-driver/v2 v2.4.1 h1:hGDMngUao03OVQ6sgV5csk+RWOIkF+CuLsTPobNMGNI=
|
||||||
go.mongodb.org/mongo-driver/v2 v2.3.0/go.mod h1:jHeEDJHJq7tm6ZF45Issun9dbogjfnPySb1vXA7EeAI=
|
go.mongodb.org/mongo-driver/v2 v2.4.1/go.mod h1:jHeEDJHJq7tm6ZF45Issun9dbogjfnPySb1vXA7EeAI=
|
||||||
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
||||||
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
||||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=
|
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func AddProbe(probe Probe) {
|
|||||||
defaultHealthManager.addProbe(probe)
|
defaultHealthManager.addProbe(probe)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateHttpHandler create health http handler base on given probe.
|
// CreateHttpHandler creates a health http handler based on the given probe.
|
||||||
func CreateHttpHandler(healthResponse string) http.HandlerFunc {
|
func CreateHttpHandler(healthResponse string) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, _ *http.Request) {
|
return func(w http.ResponseWriter, _ *http.Request) {
|
||||||
if defaultHealthManager.IsReady() {
|
if defaultHealthManager.IsReady() {
|
||||||
|
|||||||
167
readme-cn.md
167
readme-cn.md
@@ -17,7 +17,7 @@
|
|||||||
<a href="https://trendshift.io/repositories/3263" target="_blank"><img src="https://trendshift.io/api/badge/repositories/3263" alt="zeromicro%2Fgo-zero | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
<a href="https://trendshift.io/repositories/3263" target="_blank"><img src="https://trendshift.io/api/badge/repositories/3263" alt="zeromicro%2Fgo-zero | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
<a href="https://www.producthunt.com/posts/go-zero?utm_source=badge-featured&utm_medium=badge&utm_souce=badge-go-zero" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=334030&theme=light" alt="go-zero - A web & rpc framework written in Go. | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
<a href="https://www.producthunt.com/posts/go-zero?utm_source=badge-featured&utm_medium=badge&utm_souce=badge-go-zero" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=334030&theme=light" alt="go-zero - A web & rpc framework written in Go. | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||||
|
|
||||||
## 0. go-zero 介绍
|
## go-zero 介绍
|
||||||
|
|
||||||
go-zero(收录于 CNCF 云原生技术全景图:[https://landscape.cncf.io/?selected=go-zero](https://landscape.cncf.io/?selected=go-zero))是一个集成了各种工程实践的 web 和 rpc 框架。通过弹性设计保障了大并发服务端的稳定性,经受了充分的实战检验。
|
go-zero(收录于 CNCF 云原生技术全景图:[https://landscape.cncf.io/?selected=go-zero](https://landscape.cncf.io/?selected=go-zero))是一个集成了各种工程实践的 web 和 rpc 框架。通过弹性设计保障了大并发服务端的稳定性,经受了充分的实战检验。
|
||||||
|
|
||||||
@@ -25,72 +25,50 @@ go-zero 包含极简的 API 定义和生成工具 goctl,可以根据定义的
|
|||||||
|
|
||||||
使用 go-zero 的好处:
|
使用 go-zero 的好处:
|
||||||
|
|
||||||
* 轻松获得支撑千万日活服务的稳定性
|
* 经过千万日活服务验证的稳定性
|
||||||
* 内建级联超时控制、限流、自适应熔断、自适应降载等微服务治理能力,无需配置和额外代码
|
* 内建弹性保护:级联超时、限流、熔断、降载(无需配置)
|
||||||
* 微服务治理中间件可无缝集成到其它现有框架使用
|
* 极简 API 语法生成多端代码
|
||||||
* 极简的 API 描述,一键生成各端代码
|
* 自动参数校验和丰富的微服务工具包
|
||||||
* 自动校验客户端请求参数合法性
|
|
||||||
* 大量微服务治理和并发工具包
|
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
## 1. go-zero 框架背景
|
## go-zero 框架背景
|
||||||
|
|
||||||
18 年初,我们决定从 `Java+MongoDB` 的单体架构迁移到微服务架构,经过仔细思考和对比,我们决定:
|
18 年初,我们决定从 `Java+MongoDB` 的单体架构迁移到微服务架构,选择:
|
||||||
|
|
||||||
* 基于 Go 语言
|
* **基于 Go 语言** - 高效性能、简洁语法、极致部署体验、极低资源成本
|
||||||
* 高效的性能
|
* **自研微服务框架** - 更快速的问题定位、更便捷的新特性增加
|
||||||
* 简洁的语法
|
|
||||||
* 广泛验证的工程效率
|
|
||||||
* 极致的部署体验
|
|
||||||
* 极低的服务端资源成本
|
|
||||||
* 自研微服务框架
|
|
||||||
* 有过很多微服务框架自研经验
|
|
||||||
* 需要有更快速的问题定位能力
|
|
||||||
* 更便捷的增加新特性
|
|
||||||
|
|
||||||
## 2. go-zero 框架设计思考
|
## go-zero 框架设计思考
|
||||||
|
|
||||||
对于微服务框架的设计,我们期望保障微服务稳定性的同时,也要特别注重研发效率。所以设计之初,我们就有如下一些准则:
|
go-zero 遵循以下核心设计准则:
|
||||||
|
|
||||||
* 保持简单,第一原则
|
* **保持简单** - 简单是第一原则
|
||||||
* 弹性设计,面向故障编程
|
* **高可用** - 高并发、易扩展
|
||||||
* 工具大于约定和文档
|
* **弹性设计** - 面向故障编程
|
||||||
* 高可用、高并发、易扩展
|
* **工具驱动** - 工具大于约定和文档
|
||||||
* 对业务开发友好,封装复杂度
|
* **业务友好** - 封装复杂度、一事一法
|
||||||
* 约束做一件事只有一种方式
|
|
||||||
|
|
||||||
我们经历不到半年时间,彻底完成了从 `Java+MongoDB` 到 `Golang+MySQL` 为主的微服务体系迁移,并于 18 年 8 月底完全上线,稳定保障了业务后续迅速增长,确保了整个服务的高可用。
|
## go-zero 项目实现和特点
|
||||||
|
|
||||||
## 3. go-zero 项目实现和特点
|
go-zero 集成各种工程实践,主要特点:
|
||||||
|
|
||||||
go-zero 是一个集成了各种工程实践的包含 web 和 rpc 框架,有如下主要特点:
|
* **强大工具支持** - 尽可能少的代码编写
|
||||||
|
* **极简接口** - 完全兼容 net/http
|
||||||
* 强大的工具支持,尽可能少的代码编写
|
* **高性能** - 优化的速度和效率
|
||||||
* 极简的接口
|
* **弹性设计** - 内建限流、熔断、降载,自动触发、自动恢复
|
||||||
* 完全兼容 net/http
|
* **服务治理** - 内建服务发现、负载均衡、链路跟踪
|
||||||
* 支持中间件,方便扩展
|
* **开发工具** - API 参数自动校验、超时级联控制、自动缓存控制
|
||||||
* 高性能
|
|
||||||
* 面向故障编程,弹性设计
|
|
||||||
* 内建服务发现、负载均衡
|
|
||||||
* 内建限流、熔断、降载,且自动触发,自动恢复
|
|
||||||
* API 参数自动校验
|
|
||||||
* 超时级联控制
|
|
||||||
* 自动缓存控制
|
|
||||||
* 链路跟踪、统计报警等
|
|
||||||
* 高并发支撑,稳定保障了疫情期间每天的流量洪峰
|
|
||||||
|
|
||||||
如下图,我们从多个层面保障了整体服务的高可用:
|
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
## 4. 我们使用 go-zero 的基本架构图
|
## 我们使用 go-zero 的基本架构图
|
||||||
|
|
||||||
<img width="1067" alt="image" src="https://user-images.githubusercontent.com/1918356/171880582-11a86658-41c3-466c-95e7-7b1220eecc52.png">
|
<img width="1067" alt="image" src="https://user-images.githubusercontent.com/1918356/171880582-11a86658-41c3-466c-95e7-7b1220eecc52.png">
|
||||||
|
|
||||||
觉得不错的话,别忘 **star** 👏
|
觉得不错的话,别忘 **star** 👏
|
||||||
|
|
||||||
## 5. Installation
|
## Installation
|
||||||
|
|
||||||
在项目目录下通过如下命令安装:
|
在项目目录下通过如下命令安装:
|
||||||
|
|
||||||
@@ -98,7 +76,57 @@ go-zero 是一个集成了各种工程实践的包含 web 和 rpc 框架,有
|
|||||||
GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro/go-zero
|
GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro/go-zero
|
||||||
```
|
```
|
||||||
|
|
||||||
## 6. Quick Start
|
## AI 原生开发
|
||||||
|
|
||||||
|
go-zero 团队构建了完整的 AI 工具生态,让 Claude、GitHub Copilot、Cursor 生成符合 go-zero 规范的代码。
|
||||||
|
|
||||||
|
### 三大核心项目
|
||||||
|
|
||||||
|
**[ai-context](https://github.com/zeromicro/ai-context)** - AI 的工作流程指南
|
||||||
|
|
||||||
|
**[zero-skills](https://github.com/zeromicro/zero-skills)** - 模式库和示例
|
||||||
|
|
||||||
|
**[mcp-zero](https://github.com/zeromicro/mcp-zero)** - 基于 MCP 的代码生成工具
|
||||||
|
|
||||||
|
### 快速配置
|
||||||
|
|
||||||
|
#### GitHub Copilot
|
||||||
|
```bash
|
||||||
|
git submodule add https://github.com/zeromicro/ai-context.git .github/ai-context
|
||||||
|
ln -s ai-context/00-instructions.md .github/copilot-instructions.md # macOS/Linux
|
||||||
|
# Windows: mklink .github\copilot-instructions.md .github\ai-context\00-instructions.md
|
||||||
|
git submodule update --remote .github/ai-context # 更新
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Cursor
|
||||||
|
```bash
|
||||||
|
git submodule add https://github.com/zeromicro/ai-context.git .cursorrules
|
||||||
|
git submodule update --remote .cursorrules # 更新
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Windsurf
|
||||||
|
```bash
|
||||||
|
git submodule add https://github.com/zeromicro/ai-context.git .windsurfrules
|
||||||
|
git submodule update --remote .windsurfrules # 更新
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Claude Desktop
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/zeromicro/mcp-zero.git && cd mcp-zero && go build
|
||||||
|
# 配置: ~/Library/Application Support/Claude/claude_desktop_config.json
|
||||||
|
# 或: claude mcp add --transport stdio mcp-zero --env GOCTL_PATH=/path/to/goctl -- /path/to/mcp-zero
|
||||||
|
```
|
||||||
|
|
||||||
|
### 协同工作原理
|
||||||
|
|
||||||
|
AI 助手通过三个工具协同配合:
|
||||||
|
1. **ai-context** - 工作流程指导
|
||||||
|
2. **zero-skills** - 实现模式
|
||||||
|
3. **mcp-zero** - 实时代码生成
|
||||||
|
|
||||||
|
**示例**:创建新的 REST API → AI 读取 **ai-context** 了解工作流 → 调用 **mcp-zero** 生成代码 → 参考 **zero-skills** 实现模式 → 生成符合规范的代码 ✅
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
0. 完整示例请查看
|
0. 完整示例请查看
|
||||||
|
|
||||||
@@ -108,23 +136,22 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro
|
|||||||
|
|
||||||
1. 安装 goctl 工具
|
1. 安装 goctl 工具
|
||||||
|
|
||||||
`goctl` 读作 `go control`,不要读成 `go C-T-L`。`goctl` 的意思是不要被代码控制,而是要去控制它。其中的 `go` 不是指 `golang`。在设计 `goctl` 之初,我就希望通过 `工具` 来解放我们的双手👈
|
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
# Go
|
# Go
|
||||||
GOPROXY=https://goproxy.cn/,direct go install github.com/zeromicro/go-zero/tools/goctl@latest
|
GOPROXY=https://goproxy.cn/,direct go install github.com/zeromicro/go-zero/tools/goctl@latest
|
||||||
|
|
||||||
# For Mac
|
# For Mac
|
||||||
brew install goctl
|
brew install goctl
|
||||||
|
|
||||||
# docker for all platforms
|
# docker for all platforms
|
||||||
docker pull kevinwan/goctl
|
docker pull kevinwan/goctl
|
||||||
# run goctl
|
# run goctl
|
||||||
docker run --rm -it -v `pwd`:/app kevinwan/goctl --help
|
docker run --rm -it -v `pwd`:/app kevinwan/goctl --help
|
||||||
```
|
```
|
||||||
|
|
||||||
确保 goctl 可执行,并且在 $PATH 环境变量里。
|
确保 goctl 可执行并在 $PATH 环境变量里。
|
||||||
|
|
||||||
2. 快速生成 api 服务
|
2. 快速生成 api 服务
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
@@ -157,7 +184,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro
|
|||||||
* 可以在 `servicecontext.go` 里面传递依赖给 logic,比如 mysql, redis 等
|
* 可以在 `servicecontext.go` 里面传递依赖给 logic,比如 mysql, redis 等
|
||||||
* 在 api 定义的 `get/post/put/delete` 等请求对应的 logic 里增加业务处理逻辑
|
* 在 api 定义的 `get/post/put/delete` 等请求对应的 logic 里增加业务处理逻辑
|
||||||
|
|
||||||
3. 可以根据 api 文件生成前端需要的 Java, TypeScript, Dart, JavaScript 代码
|
3. 生成多语言客户端代码
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
goctl api java -api greet.api -dir greet
|
goctl api java -api greet.api -dir greet
|
||||||
@@ -165,17 +192,17 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro
|
|||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
## 7. Benchmark
|
## Benchmark
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
[测试代码见这里](https://github.com/smallnest/go-web-framework-benchmark)
|
[测试代码见这里](https://github.com/smallnest/go-web-framework-benchmark)
|
||||||
|
|
||||||
## 8. 文档
|
## 文档
|
||||||
|
|
||||||
* API 文档
|
* API 文档
|
||||||
|
|
||||||
[https://go-zero.dev/cn/](https://go-zero.dev/cn/)
|
[https://go-zero.dev](https://go-zero.dev)
|
||||||
|
|
||||||
* awesome 系列(更多文章见『微服务实践』公众号)
|
* awesome 系列(更多文章见『微服务实践』公众号)
|
||||||
|
|
||||||
@@ -192,9 +219,9 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro
|
|||||||
| [goctl-android](https://github.com/zeromicro/goctl-android) | 生成 `java (android)` 端 `http client` 请求代码 |
|
| [goctl-android](https://github.com/zeromicro/goctl-android) | 生成 `java (android)` 端 `http client` 请求代码 |
|
||||||
| [goctl-go-compact](https://github.com/zeromicro/goctl-go-compact) | 合并 `api` 里同一个 `group` 里的 `handler` 到一个 `go` 文件 |
|
| [goctl-go-compact](https://github.com/zeromicro/goctl-go-compact) | 合并 `api` 里同一个 `group` 里的 `handler` 到一个 `go` 文件 |
|
||||||
|
|
||||||
## 9. go-zero 用户
|
## go-zero 用户
|
||||||
|
|
||||||
go-zero 已被许多公司用于生产部署,接入场景如在线教育、电商业务、游戏、区块链等,目前为止,已使用 go-zero 的公司包括但不限于:
|
go-zero 已被众多公司用于生产部署,场景涵盖在线教育、电商、游戏、区块链等。目前使用 go-zero 的公司包括但不限于:
|
||||||
|
|
||||||
>1. 好未来
|
>1. 好未来
|
||||||
>2. 上海晓信信息科技有限公司(晓黑板)
|
>2. 上海晓信信息科技有限公司(晓黑板)
|
||||||
@@ -304,10 +331,14 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
|
|||||||
>106. 无锡盛算信息技术有限公司
|
>106. 无锡盛算信息技术有限公司
|
||||||
>107. 深圳市聚货通信息科技有限公司
|
>107. 深圳市聚货通信息科技有限公司
|
||||||
>108. 浙江银盾云科技有限公司
|
>108. 浙江银盾云科技有限公司
|
||||||
|
>109. 南京造世网络科技有限公司
|
||||||
|
>110. 温州飞儿云信息技术有限公司
|
||||||
|
>111. 统信软件
|
||||||
|
>112. 深圳坐标软件集团有限公司
|
||||||
|
|
||||||
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
||||||
|
|
||||||
## 10. CNCF 云原生技术全景图
|
## CNCF 云原生技术全景图
|
||||||
|
|
||||||
<p float="left">
|
<p float="left">
|
||||||
<img src="https://raw.githubusercontent.com/zeromicro/zero-doc/main/doc/images/cncf-logo.svg" width="200"/>
|
<img src="https://raw.githubusercontent.com/zeromicro/zero-doc/main/doc/images/cncf-logo.svg" width="200"/>
|
||||||
@@ -316,13 +347,13 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
|
|||||||
|
|
||||||
go-zero 收录在 [CNCF Cloud Native 云原生技术全景图](https://landscape.cncf.io/?selected=go-zero)。
|
go-zero 收录在 [CNCF Cloud Native 云原生技术全景图](https://landscape.cncf.io/?selected=go-zero)。
|
||||||
|
|
||||||
## 11. 微信公众号
|
## 微信公众号
|
||||||
|
|
||||||
`go-zero` 相关文章和视频都会在 `微服务实践` 公众号整理呈现,欢迎扫码关注 👏
|
`go-zero` 相关文章和视频都会在 `微服务实践` 公众号整理呈现,欢迎扫码关注 👏
|
||||||
|
|
||||||
<img src="https://raw.githubusercontent.com/zeromicro/zero-doc/main/doc/images/zeromicro.jpg" alt="wechat" width="600" />
|
<img src="https://raw.githubusercontent.com/zeromicro/zero-doc/main/doc/images/zeromicro.jpg" alt="wechat" width="600" />
|
||||||
|
|
||||||
## 12. 微信交流群
|
## 微信交流群
|
||||||
|
|
||||||
如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。
|
如果文档中未能覆盖的任何疑问,欢迎您在群里提出,我们会尽快答复。
|
||||||
|
|
||||||
@@ -332,10 +363,4 @@ go-zero 收录在 [CNCF Cloud Native 云原生技术全景图](https://landscape
|
|||||||
|
|
||||||
加群之前有劳点一下 ***star***,一个小小的 ***star*** 是作者们回答海量问题的动力!🤝
|
加群之前有劳点一下 ***star***,一个小小的 ***star*** 是作者们回答海量问题的动力!🤝
|
||||||
|
|
||||||
<img src="https://raw.githubusercontent.com/zeromicro/zero-doc/main/doc/images/wechat.jpg" alt="wechat" width="300" />
|
<img src="https://raw.githubusercontent.com/zeromicro/zero-doc/main/doc/images/wechat.jpg" alt="wechat" width="300" />
|
||||||
|
|
||||||
## 13. 知识星球
|
|
||||||
|
|
||||||
官方团队运营的知识星球
|
|
||||||
|
|
||||||
<img src="https://raw.githubusercontent.com/zeromicro/zero-doc/main/doc/images/zsxq.jpg" alt="知识星球" width="300" />
|
|
||||||
152
readme.md
152
readme.md
@@ -42,61 +42,39 @@ go-zero contains simple API description syntax and code generation tool called `
|
|||||||
|
|
||||||
## Backgrounds of go-zero
|
## Backgrounds of go-zero
|
||||||
|
|
||||||
In early 2018, we embarked on a transformative journey to redesign our system, transitioning from a monolithic architecture built with Java and MongoDB to a microservices architecture. After careful research and comparison, we made a deliberate choice to:
|
In early 2018, we transitioned from a Java+MongoDB monolithic architecture to microservices, choosing:
|
||||||
|
|
||||||
* Go Beyond with Golang
|
* **Golang** - High performance, simple syntax, excellent deployment experience, and low resource consumption
|
||||||
* Great performance
|
* **Self-designed microservice framework** - Better problem isolation, easier feature extension, and faster issue resolution
|
||||||
* Simple syntax
|
|
||||||
* Proven engineering efficiency
|
|
||||||
* Extreme deployment experience
|
|
||||||
* Less server resource consumption
|
|
||||||
|
|
||||||
* Self-Design Our Microservice Architecture
|
|
||||||
* Microservice architecture facilitates the creation of scalable, flexible, and maintainable software systems with independent, reusable components.
|
|
||||||
* Easy to locate the problems within microservices.
|
|
||||||
* Easy to extend the features by adding or modifying specific microservices without impacting the entire system.
|
|
||||||
|
|
||||||
## Design considerations on go-zero
|
## Design considerations on go-zero
|
||||||
|
|
||||||
By designing the microservice architecture, we expected to ensure stability, as well as productivity. And from just the beginning, we have the following design principles:
|
go-zero follows these core design principles:
|
||||||
|
|
||||||
* Keep it simple
|
* **Simplicity** - Keep it simple, first principle
|
||||||
* High availability
|
* **High availability** - Stable under high concurrency
|
||||||
* Stable on high concurrency
|
* **Resilience** - Failure-oriented programming with adaptive protection
|
||||||
* Easy to extend
|
* **Developer friendly** - Encapsulate complexity, one way to do one thing
|
||||||
* Resilience design, failure-oriented programming
|
* **Easy to extend** - Flexible architecture for growth
|
||||||
* Try best to be friendly to the business logic development, encapsulate the complexity
|
|
||||||
* One thing, one way
|
|
||||||
|
|
||||||
After almost half a year, we finished the transfer from a monolithic system to microservice system and deployed on August 2018. The new system guaranteed business growth and system stability.
|
|
||||||
|
|
||||||
## The implementation and features of go-zero
|
## The implementation and features of go-zero
|
||||||
|
|
||||||
go-zero is a web and rpc framework that integrates lots of engineering practices. The features are mainly listed below:
|
go-zero integrates engineering best practices:
|
||||||
|
|
||||||
* Powerful tool included, less code to write
|
* **Code generation** - Powerful tools to minimize boilerplate
|
||||||
* Simple interfaces
|
* **Simple API** - Clean interfaces, fully compatible with net/http
|
||||||
* Fully compatible with net/http
|
* **High performance** - Optimized for speed and efficiency
|
||||||
* Middlewares are supported, easy to extend
|
* **Resilience** - Built-in circuit breaker, rate limiting, load shedding, timeout control
|
||||||
* High performance
|
* **Service mesh** - Service discovery, load balancing, call tracing
|
||||||
* Failure-oriented programming, resilience design
|
* **Developer tools** - Auto parameter validation, cache management, metrics and monitoring
|
||||||
* Builtin service discovery, load balancing
|
|
||||||
* Builtin concurrency control, adaptive circuit breaker, adaptive load shedding, auto-trigger, auto recover
|
|
||||||
* Auto validation of API request parameters
|
|
||||||
* Chained timeout control
|
|
||||||
* Auto management of data caching
|
|
||||||
* Call tracing, metrics, and monitoring
|
|
||||||
* High concurrency protected
|
|
||||||
|
|
||||||
As below, go-zero protects the system with a couple of layers and mechanisms:
|
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
## The simplified architecture that we use with go-zero
|
## Architecture with go-zero
|
||||||
|
|
||||||
<img width="1067" alt="image" src="https://user-images.githubusercontent.com/1918356/171880372-5010d846-e8b1-4942-8fe2-e2bbb584f762.png">
|
<img width="1067" alt="image" src="https://user-images.githubusercontent.com/1918356/171880372-5010d846-e8b1-4942-8fe2-e2bbb584f762.png">
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
Run the following command under your project:
|
Run the following command under your project:
|
||||||
|
|
||||||
@@ -104,9 +82,59 @@ Run the following command under your project:
|
|||||||
go get -u github.com/zeromicro/go-zero
|
go get -u github.com/zeromicro/go-zero
|
||||||
```
|
```
|
||||||
|
|
||||||
## Quick Start
|
## AI-Native Development
|
||||||
|
|
||||||
1. Full examples can be checked out from below:
|
The go-zero team provides AI tooling for Claude, GitHub Copilot, Cursor to generate framework-compliant code.
|
||||||
|
|
||||||
|
### Three Core Projects
|
||||||
|
|
||||||
|
**[ai-context](https://github.com/zeromicro/ai-context)** - Workflow guide for AI assistants
|
||||||
|
|
||||||
|
**[zero-skills](https://github.com/zeromicro/zero-skills)** - Pattern library with examples
|
||||||
|
|
||||||
|
**[mcp-zero](https://github.com/zeromicro/mcp-zero)** - Code generation tools via Model Context Protocol
|
||||||
|
|
||||||
|
### Quick Setup
|
||||||
|
|
||||||
|
#### GitHub Copilot
|
||||||
|
```bash
|
||||||
|
git submodule add https://github.com/zeromicro/ai-context.git .github/ai-context
|
||||||
|
ln -s ai-context/00-instructions.md .github/copilot-instructions.md # macOS/Linux
|
||||||
|
# Windows: mklink .github\copilot-instructions.md .github\ai-context\00-instructions.md
|
||||||
|
git submodule update --remote .github/ai-context # Update
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Cursor
|
||||||
|
```bash
|
||||||
|
git submodule add https://github.com/zeromicro/ai-context.git .cursorrules
|
||||||
|
git submodule update --remote .cursorrules # Update
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Windsurf
|
||||||
|
```bash
|
||||||
|
git submodule add https://github.com/zeromicro/ai-context.git .windsurfrules
|
||||||
|
git submodule update --remote .windsurfrules # Update
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Claude Desktop
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/zeromicro/mcp-zero.git && cd mcp-zero && go build
|
||||||
|
# Configure: ~/Library/Application Support/Claude/claude_desktop_config.json
|
||||||
|
# Or: claude mcp add --transport stdio mcp-zero --env GOCTL_PATH=/path/to/goctl -- /path/to/mcp-zero
|
||||||
|
```
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
|
||||||
|
AI assistants use these tools together:
|
||||||
|
1. **ai-context** - workflow guidance
|
||||||
|
2. **zero-skills** - implementation patterns
|
||||||
|
3. **mcp-zero** - real-time code generation
|
||||||
|
|
||||||
|
**Example**: Creating a REST API → AI reads **ai-context** for workflow → calls **mcp-zero** to generate code → references **zero-skills** for patterns → produces production-ready code ✅
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
1. Full examples:
|
||||||
|
|
||||||
[Rapid development of microservice systems](https://github.com/zeromicro/zero-doc/blob/main/doc/shorturl-en.md)
|
[Rapid development of microservice systems](https://github.com/zeromicro/zero-doc/blob/main/doc/shorturl-en.md)
|
||||||
|
|
||||||
@@ -114,24 +142,22 @@ go get -u github.com/zeromicro/go-zero
|
|||||||
|
|
||||||
2. Install goctl
|
2. Install goctl
|
||||||
|
|
||||||
`goctl`can be read as `go control`. `goctl` means not to be controlled by code, instead, we control it. The inside `go` is not `golang`. At the very beginning, I was expecting it to help us improve productivity, and make our lives easier.
|
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
# for Go
|
# for Go
|
||||||
go install github.com/zeromicro/go-zero/tools/goctl@latest
|
go install github.com/zeromicro/go-zero/tools/goctl@latest
|
||||||
|
|
||||||
# For Mac
|
# For Mac
|
||||||
brew install goctl
|
brew install goctl
|
||||||
|
|
||||||
# docker for all platforms
|
# docker for all platforms
|
||||||
docker pull kevinwan/goctl
|
docker pull kevinwan/goctl
|
||||||
# run goctl
|
# run goctl
|
||||||
docker run --rm -it -v `pwd`:/app kevinwan/goctl --help
|
docker run --rm -it -v `pwd`:/app kevinwan/goctl --help
|
||||||
```
|
```
|
||||||
|
|
||||||
make sure goctl is executable and in your $PATH.
|
Ensure goctl is executable and in your $PATH.
|
||||||
|
|
||||||
3. Create the API file, like greet.api, you can install the plugin of goctl in vs code, api syntax is supported.
|
3. Create the API file (greet.api):
|
||||||
|
|
||||||
```go
|
```go
|
||||||
type (
|
type (
|
||||||
@@ -150,19 +176,19 @@ go get -u github.com/zeromicro/go-zero
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
the .api files also can be generated by goctl, like below:
|
Generate .api template:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
goctl api -o greet.api
|
goctl api -o greet.api
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Generate the go server-side code
|
4. Generate Go server code
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
goctl api go -api greet.api -dir greet
|
goctl api go -api greet.api -dir greet
|
||||||
```
|
```
|
||||||
|
|
||||||
the generated files look like:
|
Generated structure:
|
||||||
|
|
||||||
```Plain Text
|
```Plain Text
|
||||||
├── greet
|
├── greet
|
||||||
@@ -184,7 +210,7 @@ go get -u github.com/zeromicro/go-zero
|
|||||||
└── greet.api // api description file
|
└── greet.api // api description file
|
||||||
```
|
```
|
||||||
|
|
||||||
the generated code can be run directly:
|
Run the service:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
cd greet
|
cd greet
|
||||||
@@ -192,15 +218,15 @@ go get -u github.com/zeromicro/go-zero
|
|||||||
go run greet.go -f etc/greet-api.yaml
|
go run greet.go -f etc/greet-api.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
by default, it’s listening on port 8888, while it can be changed in the configuration file.
|
Default port: 8888 (configurable in etc/greet-api.yaml)
|
||||||
|
|
||||||
you can check it by curl:
|
Test with curl:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl -i http://localhost:8888/greet/from/you
|
curl -i http://localhost:8888/greet/from/you
|
||||||
```
|
```
|
||||||
|
|
||||||
the response looks like below:
|
Response:
|
||||||
|
|
||||||
```http
|
```http
|
||||||
HTTP/1.1 200 OK
|
HTTP/1.1 200 OK
|
||||||
@@ -208,12 +234,12 @@ go get -u github.com/zeromicro/go-zero
|
|||||||
Content-Length: 0
|
Content-Length: 0
|
||||||
```
|
```
|
||||||
|
|
||||||
5. Write the business logic code
|
5. Write business logic
|
||||||
|
|
||||||
* the dependencies can be passed into the logic within servicecontext.go, like mysql, redis, etc.
|
* Pass dependencies (mysql, redis, etc.) via servicecontext.go
|
||||||
* add the logic code in a logic package according to .api file
|
* Add logic code in the logic package per .api definition
|
||||||
|
|
||||||
6. Generate code like Java, TypeScript, Dart, JavaScript, etc. just from the api file
|
6. Generate client code for multiple languages
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
goctl api java -api greet.api -dir greet
|
goctl api java -api greet.api -dir greet
|
||||||
@@ -234,11 +260,11 @@ go get -u github.com/zeromicro/go-zero
|
|||||||
* [Rapid development of microservice systems - multiple RPCs](https://github.com/zeromicro/zero-doc/blob/main/docs/zero/bookstore-en.md)
|
* [Rapid development of microservice systems - multiple RPCs](https://github.com/zeromicro/zero-doc/blob/main/docs/zero/bookstore-en.md)
|
||||||
* [Examples](https://github.com/zeromicro/zero-examples)
|
* [Examples](https://github.com/zeromicro/zero-examples)
|
||||||
|
|
||||||
## Chat group
|
## Chat group
|
||||||
|
|
||||||
Join the chat via https://discord.gg/4JQvC5A4Fe
|
Join the chat via https://discord.gg/4JQvC5A4Fe
|
||||||
|
|
||||||
## Cloud Native Landscape
|
## Cloud Native Landscape
|
||||||
|
|
||||||
<p float="left">
|
<p float="left">
|
||||||
<img src="https://raw.githubusercontent.com/zeromicro/zero-doc/main/doc/images/cncf-logo.svg" width="200"/>
|
<img src="https://raw.githubusercontent.com/zeromicro/zero-doc/main/doc/images/cncf-logo.svg" width="200"/>
|
||||||
|
|||||||
@@ -389,7 +389,9 @@ func buildSSERoutes(routes []Route) []Route {
|
|||||||
// because SSE requires the connection to be kept alive indefinitely.
|
// because SSE requires the connection to be kept alive indefinitely.
|
||||||
rc := http.NewResponseController(w)
|
rc := http.NewResponseController(w)
|
||||||
if err := rc.SetWriteDeadline(time.Time{}); err != nil {
|
if err := rc.SetWriteDeadline(time.Time{}); err != nil {
|
||||||
logc.Errorf(r.Context(), "set conn write deadline failed: %v", err)
|
// Some ResponseWriter implementations (like timeoutWriter) don't support SetWriteDeadline.
|
||||||
|
// This is expected behavior and doesn't affect SSE functionality.
|
||||||
|
logc.Debugf(r.Context(), "unable to clear write deadline for SSE connection: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
|
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
|
||||||
|
|||||||
@@ -24,12 +24,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
limitBodyBytes = 1024
|
limitBodyBytes = 1024
|
||||||
limitDetailedBodyBytes = 4096
|
limitDetailedBodyBytes = 4096
|
||||||
defaultSlowThreshold = time.Millisecond * 500
|
defaultSlowThreshold = time.Millisecond * 500
|
||||||
|
defaultSSESlowThreshold = time.Minute * 3
|
||||||
)
|
)
|
||||||
|
|
||||||
var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
|
var (
|
||||||
|
slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
|
||||||
|
sseSlowThreshold = syncx.ForAtomicDuration(defaultSSESlowThreshold)
|
||||||
|
)
|
||||||
|
|
||||||
// LogHandler returns a middleware that logs http request and response.
|
// LogHandler returns a middleware that logs http request and response.
|
||||||
func LogHandler(next http.Handler) http.Handler {
|
func LogHandler(next http.Handler) http.Handler {
|
||||||
@@ -109,6 +113,11 @@ func SetSlowThreshold(threshold time.Duration) {
|
|||||||
slowThreshold.Set(threshold)
|
slowThreshold.Set(threshold)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSSESlowThreshold sets the slow threshold for SSE requests.
|
||||||
|
func SetSSESlowThreshold(threshold time.Duration) {
|
||||||
|
sseSlowThreshold.Set(threshold)
|
||||||
|
}
|
||||||
|
|
||||||
func dumpRequest(r *http.Request) string {
|
func dumpRequest(r *http.Request) string {
|
||||||
reqContent, err := httputil.DumpRequest(r, true)
|
reqContent, err := httputil.DumpRequest(r, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -118,6 +127,14 @@ func dumpRequest(r *http.Request) string {
|
|||||||
return string(reqContent)
|
return string(reqContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getSlowThreshold(r *http.Request) time.Duration {
|
||||||
|
if r.Header.Get(headerAccept) == valueSSE {
|
||||||
|
return sseSlowThreshold.Load()
|
||||||
|
} else {
|
||||||
|
return slowThreshold.Load()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func isOkResponse(code int) bool {
|
func isOkResponse(code int) bool {
|
||||||
// not server error
|
// not server error
|
||||||
return code < http.StatusInternalServerError
|
return code < http.StatusInternalServerError
|
||||||
@@ -129,7 +146,8 @@ func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *intern
|
|||||||
logger := logx.WithContext(r.Context()).WithDuration(duration)
|
logger := logx.WithContext(r.Context()).WithDuration(duration)
|
||||||
buf.WriteString(fmt.Sprintf("[HTTP] %s - %s %s - %s - %s",
|
buf.WriteString(fmt.Sprintf("[HTTP] %s - %s %s - %s - %s",
|
||||||
wrapStatusCode(code), wrapMethod(r.Method), r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent()))
|
wrapStatusCode(code), wrapMethod(r.Method), r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent()))
|
||||||
if duration > slowThreshold.Load() {
|
|
||||||
|
if duration > getSlowThreshold(r) {
|
||||||
logger.Slowf("[HTTP] %s - %s %s - %s - %s - slowcall(%s)",
|
logger.Slowf("[HTTP] %s - %s %s - %s - %s - slowcall(%s)",
|
||||||
wrapStatusCode(code), wrapMethod(r.Method), r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(),
|
wrapStatusCode(code), wrapMethod(r.Method), r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(),
|
||||||
timex.ReprOfDuration(duration))
|
timex.ReprOfDuration(duration))
|
||||||
@@ -160,7 +178,8 @@ func logDetails(r *http.Request, response *detailLoggedResponseWriter, timer *ut
|
|||||||
logger := logx.WithContext(r.Context())
|
logger := logx.WithContext(r.Context())
|
||||||
buf.WriteString(fmt.Sprintf("[HTTP] %s - %d - %s - %s\n=> %s\n",
|
buf.WriteString(fmt.Sprintf("[HTTP] %s - %d - %s - %s\n=> %s\n",
|
||||||
r.Method, code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r)))
|
r.Method, code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r)))
|
||||||
if duration > slowThreshold.Load() {
|
|
||||||
|
if duration > getSlowThreshold(r) {
|
||||||
logger.Slowf("[HTTP] %s - %d - %s - slowcall(%s)\n=> %s\n", r.Method, code, r.RemoteAddr,
|
logger.Slowf("[HTTP] %s - %d - %s - slowcall(%s)\n=> %s\n", r.Method, code, r.RemoteAddr,
|
||||||
timex.ReprOfDuration(duration), dumpRequest(r))
|
timex.ReprOfDuration(duration), dumpRequest(r))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -88,6 +88,96 @@ func TestLogHandlerSlow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLogHandlerSSE(t *testing.T) {
|
||||||
|
handlers := []func(handler http.Handler) http.Handler{
|
||||||
|
LogHandler,
|
||||||
|
DetailedLogHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, logHandler := range handlers {
|
||||||
|
t.Run("SSE request with normal duration", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
||||||
|
req.Header.Set(headerAccept, valueSSE)
|
||||||
|
|
||||||
|
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(defaultSlowThreshold + time.Second)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(resp, req)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SSE request exceeding SSE threshold", func(t *testing.T) {
|
||||||
|
originalThreshold := sseSlowThreshold.Load()
|
||||||
|
SetSSESlowThreshold(time.Millisecond * 100)
|
||||||
|
defer SetSSESlowThreshold(originalThreshold)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
||||||
|
req.Header.Set(headerAccept, valueSSE)
|
||||||
|
|
||||||
|
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(time.Millisecond * 150)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(resp, req)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogHandlerThresholdSelection(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
acceptHeader string
|
||||||
|
expectedIsSSE bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Regular HTTP request",
|
||||||
|
acceptHeader: "text/html",
|
||||||
|
expectedIsSSE: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SSE request",
|
||||||
|
acceptHeader: valueSSE,
|
||||||
|
expectedIsSSE: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No Accept header",
|
||||||
|
acceptHeader: "",
|
||||||
|
expectedIsSSE: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
|
||||||
|
if tt.acceptHeader != "" {
|
||||||
|
req.Header.Set(headerAccept, tt.acceptHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
SetSlowThreshold(time.Millisecond * 100)
|
||||||
|
SetSSESlowThreshold(time.Millisecond * 200)
|
||||||
|
defer func() {
|
||||||
|
SetSlowThreshold(defaultSlowThreshold)
|
||||||
|
SetSSESlowThreshold(defaultSSESlowThreshold)
|
||||||
|
}()
|
||||||
|
|
||||||
|
handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(time.Millisecond * 150)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(resp, req)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDetailedLogHandler_LargeBody(t *testing.T) {
|
func TestDetailedLogHandler_LargeBody(t *testing.T) {
|
||||||
lbuf := logtest.NewCollector(t)
|
lbuf := logtest.NewCollector(t)
|
||||||
|
|
||||||
@@ -139,6 +229,12 @@ func TestSetSlowThreshold(t *testing.T) {
|
|||||||
assert.Equal(t, time.Second, slowThreshold.Load())
|
assert.Equal(t, time.Second, slowThreshold.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetSSESlowThreshold(t *testing.T) {
|
||||||
|
assert.Equal(t, defaultSSESlowThreshold, sseSlowThreshold.Load())
|
||||||
|
SetSSESlowThreshold(time.Minute * 10)
|
||||||
|
assert.Equal(t, time.Minute*10, sseSlowThreshold.Load())
|
||||||
|
}
|
||||||
|
|
||||||
func TestWrapMethodWithColor(t *testing.T) {
|
func TestWrapMethodWithColor(t *testing.T) {
|
||||||
// no tty
|
// no tty
|
||||||
assert.Equal(t, http.MethodGet, wrapMethod(http.MethodGet))
|
assert.Equal(t, http.MethodGet, wrapMethod(http.MethodGet))
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
// TraceOption defines the method to customize an traceOptions.
|
// TraceOption defines the method to customize a traceOptions.
|
||||||
TraceOption func(options *traceOptions)
|
TraceOption func(options *traceOptions)
|
||||||
|
|
||||||
// traceOptions is TraceHandler options.
|
// traceOptions is TraceHandler options.
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ const (
|
|||||||
var (
|
var (
|
||||||
// ErrInvalidMethod is an error that indicates not a valid http method.
|
// ErrInvalidMethod is an error that indicates not a valid http method.
|
||||||
ErrInvalidMethod = errors.New("not a valid http method")
|
ErrInvalidMethod = errors.New("not a valid http method")
|
||||||
// ErrInvalidPath is an error that indicates path is not start with /.
|
// ErrInvalidPath is an error that indicates path does not start with /.
|
||||||
ErrInvalidPath = errors.New("path must begin with '/'")
|
ErrInvalidPath = errors.New("path must begin with '/'")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ Port: 0
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Handler: nil,
|
Handler: nil,
|
||||||
}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
|
}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
|
||||||
WithJwtTransition("preivous", "thenewone"))
|
WithJwtTransition("previous", "thenewone"))
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ func init() {
|
|||||||
newCmdFlags.StringVar(&new.VarStringHome, "home")
|
newCmdFlags.StringVar(&new.VarStringHome, "home")
|
||||||
newCmdFlags.StringVar(&new.VarStringRemote, "remote")
|
newCmdFlags.StringVar(&new.VarStringRemote, "remote")
|
||||||
newCmdFlags.StringVar(&new.VarStringBranch, "branch")
|
newCmdFlags.StringVar(&new.VarStringBranch, "branch")
|
||||||
|
newCmdFlags.StringVar(&new.VarStringModule, "module")
|
||||||
newCmdFlags.StringVarWithDefaultValue(&new.VarStringStyle, "style", config.DefaultFormat)
|
newCmdFlags.StringVarWithDefaultValue(&new.VarStringStyle, "style", config.DefaultFormat)
|
||||||
|
|
||||||
pluginCmdFlags.StringVarP(&plugin.VarStringPlugin, "plugin", "p")
|
pluginCmdFlags.StringVarP(&plugin.VarStringPlugin, "plugin", "p")
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ import '../vars/vars.dart';
|
|||||||
/// Send GET request.
|
/// Send GET request.
|
||||||
///
|
///
|
||||||
/// ok: the function that will be called on success.
|
/// ok: the function that will be called on success.
|
||||||
/// fail:the fuction that will be called on failure.
|
/// fail:the function that will be called on failure.
|
||||||
/// eventually:the function that will be called regardless of success or failure.
|
/// eventually:the function that will be called regardless of success or failure.
|
||||||
Future apiGet(String path,
|
Future apiGet(String path,
|
||||||
{Map<String, String> header,
|
{Map<String, String> header,
|
||||||
@@ -47,7 +47,7 @@ Future apiGet(String path,
|
|||||||
///
|
///
|
||||||
/// data: the data to post, it will be marshaled to json automatically.
|
/// data: the data to post, it will be marshaled to json automatically.
|
||||||
/// ok: the function that will be called on success.
|
/// ok: the function that will be called on success.
|
||||||
/// fail:the fuction that will be called on failure.
|
/// fail:the function that will be called on failure.
|
||||||
/// eventually:the function that will be called regardless of success or failure.
|
/// eventually:the function that will be called regardless of success or failure.
|
||||||
Future apiPost(String path, dynamic data,
|
Future apiPost(String path, dynamic data,
|
||||||
{Map<String, String> header,
|
{Map<String, String> header,
|
||||||
@@ -132,7 +132,7 @@ Future _apiRequest(String method, String path, dynamic data,
|
|||||||
/// data: any request class that will be converted to json automatically
|
/// data: any request class that will be converted to json automatically
|
||||||
/// ok: is called when request succeeds
|
/// ok: is called when request succeeds
|
||||||
/// fail: is called when request fails
|
/// fail: is called when request fails
|
||||||
/// eventually: is always called until the nearby functions returns
|
/// eventually: is always called after the nearby function returns
|
||||||
Future apiPost(String path, dynamic data,
|
Future apiPost(String path, dynamic data,
|
||||||
{Map<String, String>? header,
|
{Map<String, String>? header,
|
||||||
Function(Map<String, dynamic>)? ok,
|
Function(Map<String, dynamic>)? ok,
|
||||||
@@ -146,7 +146,7 @@ Future _apiRequest(String method, String path, dynamic data,
|
|||||||
///
|
///
|
||||||
/// ok: is called when request succeeds
|
/// ok: is called when request succeeds
|
||||||
/// fail: is called when request fails
|
/// fail: is called when request fails
|
||||||
/// eventually: is always called until the nearby functions returns
|
/// eventually: is always called after the nearby function returns
|
||||||
Future apiGet(String path,
|
Future apiGet(String path,
|
||||||
{Map<String, String>? header,
|
{Map<String, String>? header,
|
||||||
Function(Map<String, dynamic>)? ok,
|
Function(Map<String, dynamic>)? ok,
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ func DocCommand(_ *cobra.Command, _ []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !pathx.FileExists(dir) {
|
if !pathx.FileExists(dir) {
|
||||||
return fmt.Errorf("dir %s not exsit", dir)
|
return fmt.Errorf("dir %s not exist", dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
dir, err := filepath.Abs(dir)
|
dir, err := filepath.Abs(dir)
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Code scaffolded by goctl. Safe to edit.
|
||||||
|
// goctl {{.version}}
|
||||||
|
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import {{.authImport}}
|
import {{.authImport}}
|
||||||
|
|||||||
@@ -75,6 +75,11 @@ func GoCommand(_ *cobra.Command, _ []string) error {
|
|||||||
|
|
||||||
// DoGenProject gen go project files with api file
|
// DoGenProject gen go project files with api file
|
||||||
func DoGenProject(apiFile, dir, style string, withTest bool) error {
|
func DoGenProject(apiFile, dir, style string, withTest bool) error {
|
||||||
|
return DoGenProjectWithModule(apiFile, dir, "", style, withTest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoGenProjectWithModule gen go project files with api file using custom module name
|
||||||
|
func DoGenProjectWithModule(apiFile, dir, moduleName, style string, withTest bool) error {
|
||||||
api, err := parser.Parse(apiFile)
|
api, err := parser.Parse(apiFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -90,23 +95,31 @@ func DoGenProject(apiFile, dir, style string, withTest bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logx.Must(pathx.MkdirIfNotExist(dir))
|
logx.Must(pathx.MkdirIfNotExist(dir))
|
||||||
rootPkg, err := golang.GetParentPackage(dir)
|
|
||||||
|
var rootPkg, projectPkg string
|
||||||
|
if len(moduleName) > 0 {
|
||||||
|
rootPkg, projectPkg, err = golang.GetParentPackageWithModule(dir, moduleName)
|
||||||
|
} else {
|
||||||
|
rootPkg, projectPkg, err = golang.GetParentPackage(dir)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
logx.Must(genEtc(dir, cfg, api))
|
logx.Must(genEtc(dir, cfg, api))
|
||||||
logx.Must(genConfig(dir, cfg, api))
|
logx.Must(genConfig(dir, projectPkg, cfg, api))
|
||||||
logx.Must(genMain(dir, rootPkg, cfg, api))
|
logx.Must(genMain(dir, rootPkg, projectPkg, cfg, api))
|
||||||
logx.Must(genServiceContext(dir, rootPkg, cfg, api))
|
logx.Must(genServiceContext(dir, rootPkg, projectPkg, cfg, api))
|
||||||
logx.Must(genTypes(dir, cfg, api))
|
logx.Must(genTypes(dir, cfg, api))
|
||||||
logx.Must(genRoutes(dir, rootPkg, cfg, api))
|
logx.Must(genRoutes(dir, rootPkg, projectPkg, cfg, api))
|
||||||
logx.Must(genHandlers(dir, rootPkg, cfg, api))
|
logx.Must(genHandlers(dir, rootPkg, projectPkg, cfg, api))
|
||||||
logx.Must(genLogic(dir, rootPkg, cfg, api))
|
logx.Must(genLogic(dir, rootPkg, projectPkg, cfg, api))
|
||||||
logx.Must(genMiddleware(dir, cfg, api))
|
logx.Must(genMiddleware(dir, cfg, api))
|
||||||
if withTest {
|
if withTest {
|
||||||
logx.Must(genHandlersTest(dir, rootPkg, cfg, api))
|
logx.Must(genHandlersTest(dir, rootPkg, projectPkg, cfg, api))
|
||||||
logx.Must(genLogicTest(dir, rootPkg, cfg, api))
|
logx.Must(genLogicTest(dir, rootPkg, projectPkg, cfg, api))
|
||||||
|
logx.Must(genServiceContextTest(dir, rootPkg, projectPkg, cfg, api))
|
||||||
|
logx.Must(genIntegrationTest(dir, rootPkg, projectPkg, cfg, api))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := backupAndSweep(apiFile); err != nil {
|
if err := backupAndSweep(apiFile); err != nil {
|
||||||
|
|||||||
181
tools/goctl/api/gogen/gencomment_test.go
Normal file
181
tools/goctl/api/gogen/gencomment_test.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
package gogen
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestGenerationComments verifies that all generated files have appropriate generation comments
|
||||||
|
func TestGenerationComments(t *testing.T) {
|
||||||
|
// Create a temporary directory for our test
|
||||||
|
tempDir, err := os.MkdirTemp("", "goctl_test_")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Create a simple API spec for testing
|
||||||
|
apiContent := `
|
||||||
|
syntax = "v1"
|
||||||
|
|
||||||
|
type HelloRequest {
|
||||||
|
Name string ` + "`json:\"name\"`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
type HelloResponse {
|
||||||
|
Message string ` + "`json:\"message\"`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
service hello-api {
|
||||||
|
@handler helloHandler
|
||||||
|
post /hello (HelloRequest) returns (HelloResponse)
|
||||||
|
}`
|
||||||
|
|
||||||
|
// Write the API spec to a temporary file
|
||||||
|
apiFile := filepath.Join(tempDir, "test.api")
|
||||||
|
err = os.WriteFile(apiFile, []byte(apiContent), 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Parse and generate the API files using the correct function signature
|
||||||
|
err = DoGenProject(apiFile, tempDir, "gozero", false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Define expected files and their comment types
|
||||||
|
expectedFiles := map[string]string{
|
||||||
|
// Files that should have "DO NOT EDIT" comments (regenerated files)
|
||||||
|
"internal/types/types.go": "DO NOT EDIT",
|
||||||
|
|
||||||
|
// Files that should have "Safe to edit" comments (scaffolded files)
|
||||||
|
"internal/handler/hellohandler.go": "Safe to edit",
|
||||||
|
"internal/config/config.go": "Safe to edit",
|
||||||
|
"hello.go": "Safe to edit", // main file
|
||||||
|
"internal/svc/servicecontext.go": "Safe to edit",
|
||||||
|
"internal/logic/hellologic.go": "Safe to edit",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check each file for the correct generation comment
|
||||||
|
for filePath, expectedCommentType := range expectedFiles {
|
||||||
|
fullPath := filepath.Join(tempDir, filePath)
|
||||||
|
|
||||||
|
// Skip if file doesn't exist (some files might not be generated in all cases)
|
||||||
|
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
||||||
|
t.Logf("File %s does not exist, skipping", filePath)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(fullPath)
|
||||||
|
require.NoError(t, err, "Failed to read file: %s", filePath)
|
||||||
|
|
||||||
|
contentStr := string(content)
|
||||||
|
lines := strings.Split(contentStr, "\n")
|
||||||
|
|
||||||
|
// Check that the file starts with proper generation comments
|
||||||
|
require.GreaterOrEqual(t, len(lines), 2, "File %s should have at least 2 lines", filePath)
|
||||||
|
|
||||||
|
if expectedCommentType == "DO NOT EDIT" {
|
||||||
|
assert.Contains(t, lines[0], "// Code generated by goctl. DO NOT EDIT.",
|
||||||
|
"File %s should have 'DO NOT EDIT' comment as first line", filePath)
|
||||||
|
} else if expectedCommentType == "Safe to edit" {
|
||||||
|
assert.Contains(t, lines[0], "// Code scaffolded by goctl. Safe to edit.",
|
||||||
|
"File %s should have 'Safe to edit' comment as first line", filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the second line contains the version
|
||||||
|
assert.Contains(t, lines[1], "// goctl",
|
||||||
|
"File %s should have version comment as second line", filePath)
|
||||||
|
assert.Contains(t, lines[1], version.BuildVersion,
|
||||||
|
"File %s should contain version %s in second line", filePath, version.BuildVersion)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRoutesGenerationComment verifies routes files have "DO NOT EDIT" comment
|
||||||
|
func TestRoutesGenerationComment(t *testing.T) {
|
||||||
|
// Create a temporary directory for our test
|
||||||
|
tempDir, err := os.MkdirTemp("", "goctl_routes_test_")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Create an API spec with multiple handlers to ensure routes file is generated
|
||||||
|
apiContent := `
|
||||||
|
syntax = "v1"
|
||||||
|
|
||||||
|
type HelloRequest {
|
||||||
|
Name string ` + "`json:\"name\"`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
type HelloResponse {
|
||||||
|
Message string ` + "`json:\"message\"`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
service hello-api {
|
||||||
|
@handler helloHandler
|
||||||
|
post /hello (HelloRequest) returns (HelloResponse)
|
||||||
|
|
||||||
|
@handler worldHandler
|
||||||
|
get /world returns (HelloResponse)
|
||||||
|
}`
|
||||||
|
|
||||||
|
// Write the API spec to a temporary file
|
||||||
|
apiFile := filepath.Join(tempDir, "test.api")
|
||||||
|
err = os.WriteFile(apiFile, []byte(apiContent), 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Generate the API files using the correct function signature
|
||||||
|
err = DoGenProject(apiFile, tempDir, "gozero", false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check the routes file specifically
|
||||||
|
routesFile := filepath.Join(tempDir, "internal/handler/routes.go")
|
||||||
|
if _, err := os.Stat(routesFile); os.IsNotExist(err) {
|
||||||
|
t.Skip("Routes file not generated, skipping test")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := os.ReadFile(routesFile)
|
||||||
|
require.NoError(t, err, "Failed to read routes.go")
|
||||||
|
|
||||||
|
contentStr := string(content)
|
||||||
|
lines := strings.Split(contentStr, "\n")
|
||||||
|
|
||||||
|
// Check that routes.go has "DO NOT EDIT" comment
|
||||||
|
require.GreaterOrEqual(t, len(lines), 2, "Routes file should have at least 2 lines")
|
||||||
|
assert.Contains(t, lines[0], "// Code generated by goctl. DO NOT EDIT.",
|
||||||
|
"Routes file should have 'DO NOT EDIT' comment")
|
||||||
|
assert.Contains(t, lines[1], "// goctl",
|
||||||
|
"Routes file should have version comment")
|
||||||
|
assert.Contains(t, lines[1], version.BuildVersion,
|
||||||
|
"Routes file should contain version %s", version.BuildVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVersionInTemplateData verifies that version is correctly passed to templates
|
||||||
|
func TestVersionInTemplateData(t *testing.T) {
|
||||||
|
// Test that BuildVersion is available
|
||||||
|
assert.NotEmpty(t, version.BuildVersion, "BuildVersion should not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCommentsFollowGoStandards verifies our comments follow Go community standards
|
||||||
|
func TestCommentsFollowGoStandards(t *testing.T) {
|
||||||
|
// Test the format of our generation comments
|
||||||
|
doNotEditComment := "// Code generated by goctl. DO NOT EDIT."
|
||||||
|
safeToEditComment := "// Code scaffolded by goctl. Safe to edit."
|
||||||
|
|
||||||
|
// Both should be valid Go comments
|
||||||
|
assert.True(t, strings.HasPrefix(doNotEditComment, "//"),
|
||||||
|
"DO NOT EDIT comment should start with //")
|
||||||
|
assert.True(t, strings.HasPrefix(safeToEditComment, "//"),
|
||||||
|
"Safe to edit comment should start with //")
|
||||||
|
|
||||||
|
// Should contain key information
|
||||||
|
assert.Contains(t, doNotEditComment, "goctl",
|
||||||
|
"DO NOT EDIT comment should mention goctl")
|
||||||
|
assert.Contains(t, safeToEditComment, "goctl",
|
||||||
|
"Safe to edit comment should mention goctl")
|
||||||
|
assert.Contains(t, doNotEditComment, "DO NOT EDIT",
|
||||||
|
"Should clearly state DO NOT EDIT")
|
||||||
|
assert.Contains(t, safeToEditComment, "Safe to edit",
|
||||||
|
"Should clearly state Safe to edit")
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/vars"
|
"github.com/zeromicro/go-zero/tools/goctl/vars"
|
||||||
)
|
)
|
||||||
@@ -29,7 +30,7 @@ const (
|
|||||||
//go:embed config.tpl
|
//go:embed config.tpl
|
||||||
var configTemplate string
|
var configTemplate string
|
||||||
|
|
||||||
func genConfig(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
func genConfig(dir, projectPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
filename, err := format.FileNamingFormat(cfg.NamingFormat, configFile)
|
filename, err := format.FileNamingFormat(cfg.NamingFormat, configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -60,6 +61,8 @@ func genConfig(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
|||||||
"authImport": authImportStr,
|
"authImport": authImportStr,
|
||||||
"auth": strings.Join(auths, "\n"),
|
"auth": strings.Join(auths, "\n"),
|
||||||
"jwtTrans": strings.Join(jwtTransList, "\n"),
|
"jwtTrans": strings.Join(jwtTransList, "\n"),
|
||||||
|
"projectPkg": projectPkg,
|
||||||
|
"version": version.BuildVersion,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||||
@@ -22,7 +23,7 @@ var (
|
|||||||
sseHandlerTemplate string
|
sseHandlerTemplate string
|
||||||
)
|
)
|
||||||
|
|
||||||
func genHandler(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
|
func genHandler(dir, rootPkg, projectPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
|
||||||
handler := getHandlerName(route)
|
handler := getHandlerName(route)
|
||||||
handlerPath := getHandlerFolderPath(group, route)
|
handlerPath := getHandlerFolderPath(group, route)
|
||||||
pkgName := handlerPath[strings.LastIndex(handlerPath, "/")+1:]
|
pkgName := handlerPath[strings.LastIndex(handlerPath, "/")+1:]
|
||||||
@@ -37,9 +38,11 @@ func genHandler(dir, rootPkg string, cfg *config.Config, group spec.Group, route
|
|||||||
}
|
}
|
||||||
|
|
||||||
var builtinTemplate = handlerTemplate
|
var builtinTemplate = handlerTemplate
|
||||||
|
var templateFile = handlerTemplateFile
|
||||||
sse := group.GetAnnotation("sse")
|
sse := group.GetAnnotation("sse")
|
||||||
if sse == "true" {
|
if sse == "true" {
|
||||||
builtinTemplate = sseHandlerTemplate
|
builtinTemplate = sseHandlerTemplate
|
||||||
|
templateFile = sseHandlerTemplateFile
|
||||||
}
|
}
|
||||||
|
|
||||||
return genFile(fileGenConfig{
|
return genFile(fileGenConfig{
|
||||||
@@ -48,7 +51,7 @@ func genHandler(dir, rootPkg string, cfg *config.Config, group spec.Group, route
|
|||||||
filename: filename + ".go",
|
filename: filename + ".go",
|
||||||
templateName: "handlerTemplate",
|
templateName: "handlerTemplate",
|
||||||
category: category,
|
category: category,
|
||||||
templateFile: handlerTemplateFile,
|
templateFile: templateFile,
|
||||||
builtinTemplate: builtinTemplate,
|
builtinTemplate: builtinTemplate,
|
||||||
data: map[string]any{
|
data: map[string]any{
|
||||||
"PkgName": pkgName,
|
"PkgName": pkgName,
|
||||||
@@ -63,14 +66,16 @@ func genHandler(dir, rootPkg string, cfg *config.Config, group spec.Group, route
|
|||||||
"HasRequest": len(route.RequestTypeName()) > 0,
|
"HasRequest": len(route.RequestTypeName()) > 0,
|
||||||
"HasDoc": len(route.JoinedDoc()) > 0,
|
"HasDoc": len(route.JoinedDoc()) > 0,
|
||||||
"Doc": getDoc(route.JoinedDoc()),
|
"Doc": getDoc(route.JoinedDoc()),
|
||||||
|
"projectPkg": projectPkg,
|
||||||
|
"version": version.BuildVersion,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func genHandlers(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
func genHandlers(dir, rootPkg, projectPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
for _, group := range api.Service.Groups {
|
for _, group := range api.Service.Groups {
|
||||||
for _, route := range group.Routes {
|
for _, route := range group.Routes {
|
||||||
if err := genHandler(dir, rootPkg, cfg, group, route); err != nil {
|
if err := genHandler(dir, rootPkg, projectPkg, cfg, group, route); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||||
@@ -15,7 +16,7 @@ import (
|
|||||||
//go:embed handler_test.tpl
|
//go:embed handler_test.tpl
|
||||||
var handlerTestTemplate string
|
var handlerTestTemplate string
|
||||||
|
|
||||||
func genHandlerTest(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
|
func genHandlerTest(dir, rootPkg, projectPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
|
||||||
handler := getHandlerName(route)
|
handler := getHandlerName(route)
|
||||||
handlerPath := getHandlerFolderPath(group, route)
|
handlerPath := getHandlerFolderPath(group, route)
|
||||||
pkgName := handlerPath[strings.LastIndex(handlerPath, "/")+1:]
|
pkgName := handlerPath[strings.LastIndex(handlerPath, "/")+1:]
|
||||||
@@ -50,14 +51,16 @@ func genHandlerTest(dir, rootPkg string, cfg *config.Config, group spec.Group, r
|
|||||||
"HasRequest": len(route.RequestTypeName()) > 0,
|
"HasRequest": len(route.RequestTypeName()) > 0,
|
||||||
"HasDoc": len(route.JoinedDoc()) > 0,
|
"HasDoc": len(route.JoinedDoc()) > 0,
|
||||||
"Doc": getDoc(route.JoinedDoc()),
|
"Doc": getDoc(route.JoinedDoc()),
|
||||||
|
"projectPkg": projectPkg,
|
||||||
|
"version": version.BuildVersion,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func genHandlersTest(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
func genHandlersTest(dir, rootPkg, projectPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
for _, group := range api.Service.Groups {
|
for _, group := range api.Service.Groups {
|
||||||
for _, route := range group.Routes {
|
for _, route := range group.Routes {
|
||||||
if err := genHandlerTest(dir, rootPkg, cfg, group, route); err != nil {
|
if err := genHandlerTest(dir, rootPkg, projectPkg, cfg, group, route); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
42
tools/goctl/api/gogen/genintegrationtest.go
Normal file
42
tools/goctl/api/gogen/genintegrationtest.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package gogen
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed integration_test.tpl
|
||||||
|
var integrationTestTemplate string
|
||||||
|
|
||||||
|
func genIntegrationTest(dir, rootPkg, projectPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
|
serviceName := api.Service.Name
|
||||||
|
if len(serviceName) == 0 {
|
||||||
|
serviceName = "server"
|
||||||
|
}
|
||||||
|
|
||||||
|
filename, err := format.FileNamingFormat(cfg.NamingFormat, serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return genFile(fileGenConfig{
|
||||||
|
dir: dir,
|
||||||
|
subdir: "",
|
||||||
|
filename: filename + "_test.go",
|
||||||
|
templateName: "integrationTestTemplate",
|
||||||
|
category: category,
|
||||||
|
templateFile: integrationTestTemplateFile,
|
||||||
|
builtinTemplate: integrationTestTemplate,
|
||||||
|
data: map[string]any{
|
||||||
|
"projectPkg": projectPkg,
|
||||||
|
"serviceName": serviceName,
|
||||||
|
"version": version.BuildVersion,
|
||||||
|
"hasRoutes": len(api.Service.Routes()) > 0,
|
||||||
|
"routes": api.Service.Routes(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/tools/goctl/api/parser/g4/gen/api"
|
"github.com/zeromicro/go-zero/tools/goctl/api/parser/g4/gen/api"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/vars"
|
"github.com/zeromicro/go-zero/tools/goctl/vars"
|
||||||
@@ -23,10 +24,10 @@ var (
|
|||||||
sseLogicTemplate string
|
sseLogicTemplate string
|
||||||
)
|
)
|
||||||
|
|
||||||
func genLogic(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
func genLogic(dir, rootPkg, projectPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
for _, g := range api.Service.Groups {
|
for _, g := range api.Service.Groups {
|
||||||
for _, r := range g.Routes {
|
for _, r := range g.Routes {
|
||||||
err := genLogicByRoute(dir, rootPkg, cfg, g, r)
|
err := genLogicByRoute(dir, rootPkg, projectPkg, cfg, g, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -35,7 +36,7 @@ func genLogic(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
|
func genLogicByRoute(dir, rootPkg, projectPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
|
||||||
logic := getLogicName(route)
|
logic := getLogicName(route)
|
||||||
goFile, err := format.FileNamingFormat(cfg.NamingFormat, logic)
|
goFile, err := format.FileNamingFormat(cfg.NamingFormat, logic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -60,9 +61,11 @@ func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group,
|
|||||||
|
|
||||||
subDir := getLogicFolderPath(group, route)
|
subDir := getLogicFolderPath(group, route)
|
||||||
builtinTemplate := logicTemplate
|
builtinTemplate := logicTemplate
|
||||||
|
templateFile := logicTemplateFile
|
||||||
sse := group.GetAnnotation("sse")
|
sse := group.GetAnnotation("sse")
|
||||||
if sse == "true" {
|
if sse == "true" {
|
||||||
builtinTemplate = sseLogicTemplate
|
builtinTemplate = sseLogicTemplate
|
||||||
|
templateFile = sseLogicTemplateFile
|
||||||
responseString = "error"
|
responseString = "error"
|
||||||
returnString = "return nil"
|
returnString = "return nil"
|
||||||
resp := responseGoTypeName(route, typesPacket)
|
resp := responseGoTypeName(route, typesPacket)
|
||||||
@@ -79,7 +82,7 @@ func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group,
|
|||||||
filename: goFile + ".go",
|
filename: goFile + ".go",
|
||||||
templateName: "logicTemplate",
|
templateName: "logicTemplate",
|
||||||
category: category,
|
category: category,
|
||||||
templateFile: logicTemplateFile,
|
templateFile: templateFile,
|
||||||
builtinTemplate: builtinTemplate,
|
builtinTemplate: builtinTemplate,
|
||||||
data: map[string]any{
|
data: map[string]any{
|
||||||
"pkgName": subDir[strings.LastIndex(subDir, "/")+1:],
|
"pkgName": subDir[strings.LastIndex(subDir, "/")+1:],
|
||||||
@@ -91,6 +94,8 @@ func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group,
|
|||||||
"request": requestString,
|
"request": requestString,
|
||||||
"hasDoc": len(route.JoinedDoc()) > 0,
|
"hasDoc": len(route.JoinedDoc()) > 0,
|
||||||
"doc": getDoc(route.JoinedDoc()),
|
"doc": getDoc(route.JoinedDoc()),
|
||||||
|
"projectPkg": projectPkg,
|
||||||
|
"version": version.BuildVersion,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||||
)
|
)
|
||||||
@@ -14,10 +15,10 @@ import (
|
|||||||
//go:embed logic_test.tpl
|
//go:embed logic_test.tpl
|
||||||
var logicTestTemplate string
|
var logicTestTemplate string
|
||||||
|
|
||||||
func genLogicTest(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
func genLogicTest(dir, rootPkg, projectPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
for _, g := range api.Service.Groups {
|
for _, g := range api.Service.Groups {
|
||||||
for _, r := range g.Routes {
|
for _, r := range g.Routes {
|
||||||
err := genLogicTestByRoute(dir, rootPkg, cfg, g, r)
|
err := genLogicTestByRoute(dir, rootPkg, projectPkg, cfg, g, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -26,7 +27,7 @@ func genLogicTest(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func genLogicTestByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
|
func genLogicTestByRoute(dir, rootPkg, projectPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
|
||||||
logic := getLogicName(route)
|
logic := getLogicName(route)
|
||||||
goFile, err := format.FileNamingFormat(cfg.NamingFormat, logic)
|
goFile, err := format.FileNamingFormat(cfg.NamingFormat, logic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -73,6 +74,8 @@ func genLogicTestByRoute(dir, rootPkg string, cfg *config.Config, group spec.Gro
|
|||||||
"requestType": requestType,
|
"requestType": requestType,
|
||||||
"hasDoc": len(route.JoinedDoc()) > 0,
|
"hasDoc": len(route.JoinedDoc()) > 0,
|
||||||
"doc": getDoc(route.JoinedDoc()),
|
"doc": getDoc(route.JoinedDoc()),
|
||||||
|
"projectPkg": projectPkg,
|
||||||
|
"version": version.BuildVersion,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/vars"
|
"github.com/zeromicro/go-zero/tools/goctl/vars"
|
||||||
@@ -15,7 +16,7 @@ import (
|
|||||||
//go:embed main.tpl
|
//go:embed main.tpl
|
||||||
var mainTemplate string
|
var mainTemplate string
|
||||||
|
|
||||||
func genMain(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
func genMain(dir, rootPkg, projectPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
name := strings.ToLower(api.Service.Name)
|
name := strings.ToLower(api.Service.Name)
|
||||||
filename, err := format.FileNamingFormat(cfg.NamingFormat, name)
|
filename, err := format.FileNamingFormat(cfg.NamingFormat, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -38,6 +39,8 @@ func genMain(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
|||||||
data: map[string]string{
|
data: map[string]string{
|
||||||
"importPackages": genMainImports(rootPkg),
|
"importPackages": genMainImports(rootPkg),
|
||||||
"serviceName": configName,
|
"serviceName": configName,
|
||||||
|
"projectPkg": projectPkg,
|
||||||
|
"version": version.BuildVersion,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,7 +32,8 @@ func genMiddleware(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
|||||||
templateFile: middlewareImplementCodeFile,
|
templateFile: middlewareImplementCodeFile,
|
||||||
builtinTemplate: middlewareImplementCode,
|
builtinTemplate: middlewareImplementCode,
|
||||||
data: map[string]string{
|
data: map[string]string{
|
||||||
"name": strings.Title(name),
|
"name": strings.Title(name),
|
||||||
|
"version": version.BuildVersion,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ type (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
func genRoutes(dir, rootPkg, projectPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
groups, err := getRoutes(api)
|
groups, err := getRoutes(api)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -211,6 +211,7 @@ rest.WithPrefix("%s"),`, g.prefix)
|
|||||||
"importPackages": genRouteImports(rootPkg, api),
|
"importPackages": genRouteImports(rootPkg, api),
|
||||||
"routesAdditions": strings.TrimSpace(builder.String()),
|
"routesAdditions": strings.TrimSpace(builder.String()),
|
||||||
"version": version.BuildVersion,
|
"version": version.BuildVersion,
|
||||||
|
"projectPkg": projectPkg,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
153
tools/goctl/api/gogen/gensse_test.go
Normal file
153
tools/goctl/api/gogen/gensse_test.go
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
package gogen
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSSEGeneration(t *testing.T) {
|
||||||
|
// Create a temporary directory for test
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
// Create a test API file with SSE annotation
|
||||||
|
apiContent := `syntax = "v1"
|
||||||
|
|
||||||
|
type SseReq {
|
||||||
|
Message string ` + "`json:\"message\"`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
type SseResp {
|
||||||
|
Data string ` + "`json:\"data\"`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
@server (
|
||||||
|
sse: true
|
||||||
|
)
|
||||||
|
service Test {
|
||||||
|
@handler Sse
|
||||||
|
get /sse (SseReq) returns (SseResp)
|
||||||
|
}
|
||||||
|
`
|
||||||
|
apiFile := filepath.Join(dir, "test.api")
|
||||||
|
err := os.WriteFile(apiFile, []byte(apiContent), 0644)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Generate code
|
||||||
|
err = DoGenProject(apiFile, dir, "gozero", false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Read generated handler file
|
||||||
|
handlerPath := filepath.Join(dir, "internal/handler/ssehandler.go")
|
||||||
|
handlerContent, err := os.ReadFile(handlerPath)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Read generated logic file
|
||||||
|
logicPath := filepath.Join(dir, "internal/logic/sselogic.go")
|
||||||
|
logicContent, err := os.ReadFile(logicPath)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
handlerStr := string(handlerContent)
|
||||||
|
logicStr := string(logicContent)
|
||||||
|
|
||||||
|
// Verify SSE-specific patterns in handler
|
||||||
|
// Handler should call: err := l.Sse(&req, client)
|
||||||
|
assert.Contains(t, handlerStr, "err := l.Sse(&req, client)",
|
||||||
|
"Handler should call logic with client channel parameter")
|
||||||
|
|
||||||
|
// Handler should NOT have the regular pattern: resp, err := l.Sse(&req)
|
||||||
|
assert.NotContains(t, handlerStr, "resp, err := l.Sse(&req)",
|
||||||
|
"Handler should not use regular pattern with resp return")
|
||||||
|
|
||||||
|
// Handler should use threading.GoSafeCtx
|
||||||
|
assert.Contains(t, handlerStr, "threading.GoSafeCtx",
|
||||||
|
"Handler should use threading.GoSafeCtx for SSE")
|
||||||
|
|
||||||
|
// Handler should create client channel
|
||||||
|
assert.Contains(t, handlerStr, "client := make(chan",
|
||||||
|
"Handler should create client channel")
|
||||||
|
|
||||||
|
// Verify SSE-specific patterns in logic
|
||||||
|
// Logic should have signature: Sse(req *types.SseReq, client chan<- *types.SseResp) error
|
||||||
|
assert.Contains(t, logicStr, "func (l *SseLogic) Sse(req *types.SseReq, client chan<- *types.SseResp) error",
|
||||||
|
"Logic should have SSE signature with client channel parameter")
|
||||||
|
|
||||||
|
// Logic should NOT have regular signature: Sse(req *types.SseReq) (resp *types.SseResp, err error)
|
||||||
|
assert.NotContains(t, logicStr, "(resp *types.SseResp, err error)",
|
||||||
|
"Logic should not have regular signature with resp return")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNonSSEGeneration(t *testing.T) {
|
||||||
|
// Create a temporary directory for test
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
// Create a test API file WITHOUT SSE annotation
|
||||||
|
apiContent := `syntax = "v1"
|
||||||
|
|
||||||
|
type SseReq {
|
||||||
|
Message string ` + "`json:\"message\"`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
type SseResp {
|
||||||
|
Data string ` + "`json:\"data\"`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
service Test {
|
||||||
|
@handler Sse
|
||||||
|
get /sse (SseReq) returns (SseResp)
|
||||||
|
}
|
||||||
|
`
|
||||||
|
apiFile := filepath.Join(dir, "test.api")
|
||||||
|
err := os.WriteFile(apiFile, []byte(apiContent), 0644)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Generate code
|
||||||
|
err = DoGenProject(apiFile, dir, "gozero", false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Read generated handler file
|
||||||
|
handlerPath := filepath.Join(dir, "internal/handler/ssehandler.go")
|
||||||
|
handlerContent, err := os.ReadFile(handlerPath)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Read generated logic file
|
||||||
|
logicPath := filepath.Join(dir, "internal/logic/sselogic.go")
|
||||||
|
logicContent, err := os.ReadFile(logicPath)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
handlerStr := string(handlerContent)
|
||||||
|
logicStr := string(logicContent)
|
||||||
|
|
||||||
|
// Verify regular (non-SSE) patterns in handler
|
||||||
|
// Handler should call: resp, err := l.Sse(&req)
|
||||||
|
assert.Contains(t, handlerStr, "resp, err := l.Sse(&req)",
|
||||||
|
"Handler should use regular pattern with resp return")
|
||||||
|
|
||||||
|
// Handler should NOT have SSE pattern: err := l.Sse(&req, client)
|
||||||
|
assert.NotContains(t, handlerStr, "err := l.Sse(&req, client)",
|
||||||
|
"Handler should not use SSE pattern")
|
||||||
|
|
||||||
|
// Handler should NOT use threading.GoSafeCtx
|
||||||
|
assert.NotContains(t, handlerStr, "threading.GoSafeCtx",
|
||||||
|
"Handler should not use threading.GoSafeCtx for regular routes")
|
||||||
|
|
||||||
|
// Verify regular (non-SSE) patterns in logic
|
||||||
|
// Logic should have signature: Sse(req *types.SseReq) (resp *types.SseResp, err error)
|
||||||
|
assert.Contains(t, logicStr, "(resp *types.SseResp, err error)",
|
||||||
|
"Logic should have regular signature with resp return")
|
||||||
|
|
||||||
|
// Logic should NOT have SSE signature with client parameter
|
||||||
|
linesToCheck := strings.Split(logicStr, "\n")
|
||||||
|
hasSSESignature := false
|
||||||
|
for _, line := range linesToCheck {
|
||||||
|
if strings.Contains(line, "func (l *SseLogic) Sse") && strings.Contains(line, "client chan<-") {
|
||||||
|
hasSSESignature = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.False(t, hasSSESignature,
|
||||||
|
"Logic should not have SSE signature with client channel parameter")
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/vars"
|
"github.com/zeromicro/go-zero/tools/goctl/vars"
|
||||||
@@ -17,7 +18,7 @@ const contextFilename = "service_context"
|
|||||||
//go:embed svc.tpl
|
//go:embed svc.tpl
|
||||||
var contextTemplate string
|
var contextTemplate string
|
||||||
|
|
||||||
func genServiceContext(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
func genServiceContext(dir, rootPkg, projectPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
filename, err := format.FileNamingFormat(cfg.NamingFormat, contextFilename)
|
filename, err := format.FileNamingFormat(cfg.NamingFormat, contextFilename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -53,6 +54,8 @@ func genServiceContext(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpe
|
|||||||
"config": "config.Config",
|
"config": "config.Config",
|
||||||
"middleware": middlewareStr,
|
"middleware": middlewareStr,
|
||||||
"middlewareAssignment": middlewareAssignment,
|
"middlewareAssignment": middlewareAssignment,
|
||||||
|
"projectPkg": projectPkg,
|
||||||
|
"version": version.BuildVersion,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
34
tools/goctl/api/gogen/gensvctest.go
Normal file
34
tools/goctl/api/gogen/gensvctest.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package gogen
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed svc_test.tpl
|
||||||
|
var svcTestTemplate string
|
||||||
|
|
||||||
|
func genServiceContextTest(dir, rootPkg, projectPkg string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||||
|
filename, err := format.FileNamingFormat(cfg.NamingFormat, contextFilename)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return genFile(fileGenConfig{
|
||||||
|
dir: dir,
|
||||||
|
subdir: contextDir,
|
||||||
|
filename: filename + "_test.go",
|
||||||
|
templateName: "svcTestTemplate",
|
||||||
|
category: category,
|
||||||
|
templateFile: svcTestTemplateFile,
|
||||||
|
builtinTemplate: svcTestTemplate,
|
||||||
|
data: map[string]any{
|
||||||
|
"projectPkg": projectPkg,
|
||||||
|
"version": version.BuildVersion,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Code scaffolded by goctl. Safe to edit.
|
||||||
|
// goctl {{.version}}
|
||||||
|
|
||||||
package {{.PkgName}}
|
package {{.PkgName}}
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Code scaffolded by goctl. Safe to edit.
|
||||||
|
// goctl {{.version}}
|
||||||
|
|
||||||
package {{.PkgName}}
|
package {{.PkgName}}
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
116
tools/goctl/api/gogen/integration_test.tpl
Normal file
116
tools/goctl/api/gogen/integration_test.tpl
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
// Code scaffolded by goctl. Safe to edit.
|
||||||
|
// goctl {{.version}}
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"{{.projectPkg}}/internal/config"
|
||||||
|
"{{.projectPkg}}/internal/handler"
|
||||||
|
"{{.projectPkg}}/internal/svc"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/zeromicro/go-zero/rest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
// TODO: Add setup/teardown logic here if needed
|
||||||
|
m.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerIntegration(t *testing.T) {
|
||||||
|
// Create test server
|
||||||
|
c := config.Config{
|
||||||
|
RestConf: rest.RestConf{
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: 0, // Use random available port
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server := rest.MustNewServer(c.RestConf)
|
||||||
|
defer server.Stop()
|
||||||
|
|
||||||
|
ctx := svc.NewServiceContext(c)
|
||||||
|
handler.RegisterHandlers(server, ctx)
|
||||||
|
|
||||||
|
// Create serverless wrapper for testing
|
||||||
|
serverless, err := rest.NewServerless(server)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
body string
|
||||||
|
expectedStatus int
|
||||||
|
setup func()
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "health check",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/health",
|
||||||
|
expectedStatus: http.StatusNotFound, // Adjust based on actual routes
|
||||||
|
setup: func() {},
|
||||||
|
},
|
||||||
|
{{if .hasRoutes}}{{range .routes}}{
|
||||||
|
name: "{{.Method}} {{.Path}}",
|
||||||
|
method: "{{.Method}}",
|
||||||
|
path: "{{.Path}}",
|
||||||
|
expectedStatus: http.StatusOK, // TODO: Adjust expected status
|
||||||
|
setup: func() {
|
||||||
|
// TODO: Add setup logic for this endpoint
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{{end}}{{end}}{
|
||||||
|
name: "not found route",
|
||||||
|
method: http.MethodGet,
|
||||||
|
path: "/nonexistent",
|
||||||
|
expectedStatus: http.StatusNotFound,
|
||||||
|
setup: func() {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setup()
|
||||||
|
|
||||||
|
req, err := http.NewRequest(tt.method, tt.path, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
serverless.Serve(rr, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedStatus, rr.Code)
|
||||||
|
|
||||||
|
// TODO: Add response body assertions
|
||||||
|
t.Logf("Response: %s", rr.Body.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerLifecycle(t *testing.T) {
|
||||||
|
c := config.Config{
|
||||||
|
RestConf: rest.RestConf{
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
server := rest.MustNewServer(c.RestConf)
|
||||||
|
|
||||||
|
// Test server can start and stop without errors
|
||||||
|
ctx := svc.NewServiceContext(c)
|
||||||
|
handler.RegisterHandlers(server, ctx)
|
||||||
|
|
||||||
|
// In a real integration test, you might start the server in a goroutine
|
||||||
|
// and test actual HTTP requests, but for scaffolding we keep it simple
|
||||||
|
server.Stop()
|
||||||
|
|
||||||
|
// TODO: Add more lifecycle tests as needed
|
||||||
|
assert.True(t, true, "Server lifecycle test passed")
|
||||||
|
}
|
||||||
17
tools/goctl/api/gogen/jwt.api
Executable file
17
tools/goctl/api/gogen/jwt.api
Executable file
@@ -0,0 +1,17 @@
|
|||||||
|
type Request {
|
||||||
|
Name string `path:"name,options=you|me"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Response {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
@server(
|
||||||
|
jwt: Auth
|
||||||
|
jwtTransition: Trans
|
||||||
|
middleware: TokenValidate
|
||||||
|
)
|
||||||
|
service A-api {
|
||||||
|
@handler GreetHandler
|
||||||
|
get /greet/from/:name(Request) returns (Response)
|
||||||
|
}
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Code scaffolded by goctl. Safe to edit.
|
||||||
|
// goctl {{.version}}
|
||||||
|
|
||||||
package {{.pkgName}}
|
package {{.pkgName}}
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Code scaffolded by goctl. Safe to edit.
|
||||||
|
// goctl {{.version}}
|
||||||
|
|
||||||
package {{.pkgName}}
|
package {{.pkgName}}
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Code scaffolded by goctl. Safe to edit.
|
||||||
|
// goctl {{.version}}
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Code scaffolded by goctl. Safe to edit.
|
||||||
|
// goctl {{.version}}
|
||||||
|
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import "net/http"
|
import "net/http"
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Code scaffolded by goctl. Safe to edit.
|
||||||
|
// goctl {{.version}}
|
||||||
|
|
||||||
package {{.PkgName}}
|
package {{.PkgName}}
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -27,11 +30,10 @@ func {{.HandlerName}}(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
|||||||
// w.Header().Set("Cache-Control", "no-cache")
|
// w.Header().Set("Cache-Control", "no-cache")
|
||||||
// w.Header().Set("Connection", "keep-alive")
|
// w.Header().Set("Connection", "keep-alive")
|
||||||
client := make(chan {{.ResponseType}}, 16)
|
client := make(chan {{.ResponseType}}, 16)
|
||||||
defer func() {
|
|
||||||
close(client)
|
|
||||||
}()
|
|
||||||
l := {{.LogicName}}.New{{.LogicType}}(r.Context(), svcCtx)
|
l := {{.LogicName}}.New{{.LogicType}}(r.Context(), svcCtx)
|
||||||
threading.GoSafeCtx(r.Context(), func() {
|
threading.GoSafeCtx(r.Context(), func() {
|
||||||
|
defer close(client)
|
||||||
err := l.{{.Call}}({{if .HasRequest}}&req, {{end}}client)
|
err := l.{{.Call}}({{if .HasRequest}}&req, {{end}}client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logc.Errorw(r.Context(), "{{.HandlerName}}", logc.Field("error", err))
|
logc.Errorw(r.Context(), "{{.HandlerName}}", logc.Field("error", err))
|
||||||
@@ -41,7 +43,10 @@ func {{.HandlerName}}(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case data := <-client:
|
case data, ok := <-client:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
output, err := json.Marshal(data)
|
output, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logc.Errorw(r.Context(), "{{.HandlerName}}", logc.Field("error", err))
|
logc.Errorw(r.Context(), "{{.HandlerName}}", logc.Field("error", err))
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Code scaffolded by goctl. Safe to edit.
|
||||||
|
// goctl {{.version}}
|
||||||
|
|
||||||
package {{.pkgName}}
|
package {{.pkgName}}
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Code scaffolded by goctl. Safe to edit.
|
||||||
|
// goctl {{.version}}
|
||||||
|
|
||||||
package svc
|
package svc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
60
tools/goctl/api/gogen/svc_test.tpl
Normal file
60
tools/goctl/api/gogen/svc_test.tpl
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// Code scaffolded by goctl. Safe to edit.
|
||||||
|
// goctl {{.version}}
|
||||||
|
|
||||||
|
package svc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"{{.projectPkg}}/internal/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewServiceContext(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config config.Config
|
||||||
|
setup func() config.Config
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default config",
|
||||||
|
setup: func() config.Config {
|
||||||
|
return config.Config{}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
setup: func() config.Config {
|
||||||
|
return config.Config{
|
||||||
|
// TODO: Add valid config values here
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := tt.setup()
|
||||||
|
svcCtx := NewServiceContext(c)
|
||||||
|
|
||||||
|
// Basic assertions
|
||||||
|
require.NotNil(t, svcCtx)
|
||||||
|
assert.Equal(t, c, svcCtx.Config)
|
||||||
|
|
||||||
|
// TODO: Add additional assertions for middleware and dependencies
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServiceContext_Initialization(t *testing.T) {
|
||||||
|
c := config.Config{}
|
||||||
|
svcCtx := NewServiceContext(c)
|
||||||
|
|
||||||
|
// Verify service context is properly initialized
|
||||||
|
assert.NotNil(t, svcCtx)
|
||||||
|
assert.Equal(t, c, svcCtx.Config)
|
||||||
|
|
||||||
|
// TODO: Add tests for middleware initialization if any
|
||||||
|
// TODO: Add tests for external dependencies if any
|
||||||
|
}
|
||||||
@@ -22,6 +22,8 @@ const (
|
|||||||
routesTemplateFile = "routes.tpl"
|
routesTemplateFile = "routes.tpl"
|
||||||
routesAdditionTemplateFile = "route-addition.tpl"
|
routesAdditionTemplateFile = "route-addition.tpl"
|
||||||
typesTemplateFile = "types.tpl"
|
typesTemplateFile = "types.tpl"
|
||||||
|
svcTestTemplateFile = "svc_test.tpl"
|
||||||
|
integrationTestTemplateFile = "integration_test.tpl"
|
||||||
)
|
)
|
||||||
|
|
||||||
var templates = map[string]string{
|
var templates = map[string]string{
|
||||||
@@ -39,6 +41,8 @@ var templates = map[string]string{
|
|||||||
routesTemplateFile: routesTemplate,
|
routesTemplateFile: routesTemplate,
|
||||||
routesAdditionTemplateFile: routesAdditionTemplate,
|
routesAdditionTemplateFile: routesAdditionTemplate,
|
||||||
typesTemplateFile: typesTemplate,
|
typesTemplateFile: typesTemplate,
|
||||||
|
svcTestTemplateFile: svcTestTemplate,
|
||||||
|
integrationTestTemplateFile: integrationTestTemplate,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Category returns the category of the api files.
|
// Category returns the category of the api files.
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ var (
|
|||||||
VarStringBranch string
|
VarStringBranch string
|
||||||
// VarStringStyle describes the style of output files.
|
// VarStringStyle describes the style of output files.
|
||||||
VarStringStyle string
|
VarStringStyle string
|
||||||
|
// VarStringModule describes the module name for go.mod.
|
||||||
|
VarStringModule string
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateServiceCommand fast create service
|
// CreateServiceCommand fast create service
|
||||||
@@ -83,6 +85,6 @@ func CreateServiceCommand(_ *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = gogen.DoGenProject(apiFilePath, abs, VarStringStyle, false)
|
err = gogen.DoGenProjectWithModule(apiFilePath, abs, VarStringModule, VarStringStyle, false)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
205
tools/goctl/api/new/newservice_test.go
Normal file
205
tools/goctl/api/new/newservice_test.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package new
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/api/gogen"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDoGenProjectWithModule_Integration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
moduleName string
|
||||||
|
serviceName string
|
||||||
|
expectedMod string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with custom module",
|
||||||
|
moduleName: "github.com/test/customapi",
|
||||||
|
serviceName: "myservice",
|
||||||
|
expectedMod: "github.com/test/customapi",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with empty module",
|
||||||
|
moduleName: "",
|
||||||
|
serviceName: "myservice",
|
||||||
|
expectedMod: "myservice",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with simple module",
|
||||||
|
moduleName: "simpleapi",
|
||||||
|
serviceName: "testapi",
|
||||||
|
expectedMod: "simpleapi",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Create temporary directory
|
||||||
|
tempDir, err := os.MkdirTemp("", "goctl-api-module-test-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Create service directory
|
||||||
|
serviceDir := filepath.Join(tempDir, tt.serviceName)
|
||||||
|
err = os.MkdirAll(serviceDir, 0755)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a simple API file for testing
|
||||||
|
apiContent := `syntax = "v1"
|
||||||
|
|
||||||
|
type Request {
|
||||||
|
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
type Response {
|
||||||
|
Message string ` + "`" + `json:"message"` + "`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
service ` + tt.serviceName + `-api {
|
||||||
|
@handler ` + tt.serviceName + `Handler
|
||||||
|
get /from/:name(Request) returns (Response)
|
||||||
|
}
|
||||||
|
`
|
||||||
|
apiFile := filepath.Join(serviceDir, tt.serviceName+".api")
|
||||||
|
err = os.WriteFile(apiFile, []byte(apiContent), 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Call the module-aware service creation function
|
||||||
|
err = gogen.DoGenProjectWithModule(apiFile, serviceDir, tt.moduleName, config.DefaultFormat, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Check go.mod file
|
||||||
|
goModPath := filepath.Join(serviceDir, "go.mod")
|
||||||
|
assert.FileExists(t, goModPath)
|
||||||
|
|
||||||
|
// Verify module name in go.mod
|
||||||
|
content, err := os.ReadFile(goModPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(content), "module "+tt.expectedMod)
|
||||||
|
|
||||||
|
// Check basic directory structure was created
|
||||||
|
assert.DirExists(t, filepath.Join(serviceDir, "etc"))
|
||||||
|
assert.DirExists(t, filepath.Join(serviceDir, "internal"))
|
||||||
|
assert.DirExists(t, filepath.Join(serviceDir, "internal", "handler"))
|
||||||
|
assert.DirExists(t, filepath.Join(serviceDir, "internal", "logic"))
|
||||||
|
assert.DirExists(t, filepath.Join(serviceDir, "internal", "svc"))
|
||||||
|
assert.DirExists(t, filepath.Join(serviceDir, "internal", "types"))
|
||||||
|
assert.DirExists(t, filepath.Join(serviceDir, "internal", "config"))
|
||||||
|
|
||||||
|
// Check that main.go imports use correct module
|
||||||
|
mainGoPath := filepath.Join(serviceDir, tt.serviceName+".go")
|
||||||
|
if _, err := os.Stat(mainGoPath); err == nil {
|
||||||
|
mainContent, err := os.ReadFile(mainGoPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Check for import of internal packages with correct module path
|
||||||
|
assert.Contains(t, string(mainContent), `"`+tt.expectedMod+"/internal/")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateServiceCommand_Integration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
moduleName string
|
||||||
|
serviceName string
|
||||||
|
expectedMod string
|
||||||
|
shouldError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid service with custom module",
|
||||||
|
moduleName: "github.com/example/testapi",
|
||||||
|
serviceName: "myapi",
|
||||||
|
expectedMod: "github.com/example/testapi",
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid service with no module",
|
||||||
|
moduleName: "",
|
||||||
|
serviceName: "simpleapi",
|
||||||
|
expectedMod: "simpleapi",
|
||||||
|
shouldError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid service name with hyphens",
|
||||||
|
moduleName: "github.com/test/api",
|
||||||
|
serviceName: "my-api",
|
||||||
|
expectedMod: "",
|
||||||
|
shouldError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.shouldError && tt.serviceName == "my-api" {
|
||||||
|
// Test that service names with hyphens are rejected
|
||||||
|
// This is tested in the actual command function, not the generate function
|
||||||
|
assert.Contains(t, tt.serviceName, "-")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create temporary directory
|
||||||
|
tempDir, err := os.MkdirTemp("", "goctl-create-service-test-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
|
// Change to temp directory
|
||||||
|
oldDir, _ := os.Getwd()
|
||||||
|
defer os.Chdir(oldDir)
|
||||||
|
os.Chdir(tempDir)
|
||||||
|
|
||||||
|
// Set the module variable as the command would
|
||||||
|
VarStringModule = tt.moduleName
|
||||||
|
VarStringStyle = config.DefaultFormat
|
||||||
|
|
||||||
|
// Create the service directory manually since we're testing the core functionality
|
||||||
|
serviceDir := filepath.Join(tempDir, tt.serviceName)
|
||||||
|
|
||||||
|
// Simulate what CreateServiceCommand does - create API file and call DoGenProjectWithModule
|
||||||
|
err = os.MkdirAll(serviceDir, 0755)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create API file
|
||||||
|
apiContent := `syntax = "v1"
|
||||||
|
|
||||||
|
type Request {
|
||||||
|
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
type Response {
|
||||||
|
Message string ` + "`" + `json:"message"` + "`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
service ` + tt.serviceName + `-api {
|
||||||
|
@handler ` + tt.serviceName + `Handler
|
||||||
|
get /from/:name(Request) returns (Response)
|
||||||
|
}
|
||||||
|
`
|
||||||
|
apiFile := filepath.Join(serviceDir, tt.serviceName+".api")
|
||||||
|
err = os.WriteFile(apiFile, []byte(apiContent), 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Call DoGenProjectWithModule as CreateServiceCommand does
|
||||||
|
err = gogen.DoGenProjectWithModule(apiFile, serviceDir, VarStringModule, VarStringStyle, false)
|
||||||
|
|
||||||
|
if tt.shouldError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify go.mod
|
||||||
|
goModPath := filepath.Join(serviceDir, "go.mod")
|
||||||
|
assert.FileExists(t, goModPath)
|
||||||
|
content, err := os.ReadFile(goModPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(content), "module "+tt.expectedMod)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -28,7 +28,7 @@ type (
|
|||||||
syntax *SyntaxExpr
|
syntax *SyntaxExpr
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParserOption defines an function with argument Parser
|
// ParserOption defines a function with argument Parser
|
||||||
ParserOption func(p *Parser)
|
ParserOption func(p *Parser)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -268,7 +268,7 @@ func (v *ApiVisitor) VisitReplybody(ctx *api.ReplybodyContext) any {
|
|||||||
v.panic(lit.Expr(), fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit.Expr().Text()))
|
v.panic(lit.Expr(), fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit.Expr().Text()))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text()))
|
v.panic(dt.Expr(), fmt.Sprintf("unsupported %s", dt.Expr().Text()))
|
||||||
}
|
}
|
||||||
case *Literal:
|
case *Literal:
|
||||||
lit := dataType.Literal.Text()
|
lit := dataType.Literal.Text()
|
||||||
@@ -276,7 +276,7 @@ func (v *ApiVisitor) VisitReplybody(ctx *api.ReplybodyContext) any {
|
|||||||
v.panic(dataType.Literal, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit))
|
v.panic(dataType.Literal, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text()))
|
v.panic(dt.Expr(), fmt.Sprintf("unsupported %s", dt.Expr().Text()))
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Body{
|
return &Body{
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ func (v *ApiVisitor) VisitTypeBlockStruct(ctx *api.TypeBlockStructContext) any {
|
|||||||
structExpr := v.newExprWithToken(ctx.GetStructToken())
|
structExpr := v.newExprWithToken(ctx.GetStructToken())
|
||||||
structTokenText := ctx.GetStructToken().GetText()
|
structTokenText := ctx.GetStructToken().GetText()
|
||||||
if structTokenText != "struct" {
|
if structTokenText != "struct" {
|
||||||
v.panic(structExpr, fmt.Sprintf("expecting 'struct', found imput '%s'", structTokenText))
|
v.panic(structExpr, fmt.Sprintf("expecting 'struct', found input '%s'", structTokenText))
|
||||||
}
|
}
|
||||||
|
|
||||||
if api.IsGolangKeyWord(structTokenText, "struct") {
|
if api.IsGolangKeyWord(structTokenText, "struct") {
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ type parser struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse parses the api file.
|
// Parse parses the api file.
|
||||||
// Depreacted: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
|
// Deprecated: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
|
||||||
// it will be removed in the future.
|
// it will be removed in the future.
|
||||||
func Parse(filename string) (*spec.ApiSpec, error) {
|
func Parse(filename string) (*spec.ApiSpec, error) {
|
||||||
if env.UseExperimental() {
|
if env.UseExperimental() {
|
||||||
@@ -63,14 +63,14 @@ func parseContent(content string, skipCheckTypeDeclaration bool, filename ...str
|
|||||||
return apiSpec, nil
|
return apiSpec, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Depreacted: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
|
// Deprecated: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
|
||||||
// it will be removed in the future.
|
// it will be removed in the future.
|
||||||
// ParseContent parses the api content
|
// ParseContent parses the api content
|
||||||
func ParseContent(content string, filename ...string) (*spec.ApiSpec, error) {
|
func ParseContent(content string, filename ...string) (*spec.ApiSpec, error) {
|
||||||
return parseContent(content, false, filename...)
|
return parseContent(content, false, filename...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Depreacted: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
|
// Deprecated: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
|
||||||
// it will be removed in the future.
|
// it will be removed in the future.
|
||||||
// ParseContentWithParserSkipCheckTypeDeclaration parses the api content with skip check type declaration
|
// ParseContentWithParserSkipCheckTypeDeclaration parses the api content with skip check type declaration
|
||||||
func ParseContentWithParserSkipCheckTypeDeclaration(content string, filename ...string) (*spec.ApiSpec, error) {
|
func ParseContentWithParserSkipCheckTypeDeclaration(content string, filename ...string) (*spec.ApiSpec, error) {
|
||||||
@@ -227,7 +227,7 @@ func (p parser) astTypeToSpec(in ast.DataType) spec.Type {
|
|||||||
return spec.PointerType{RawName: v.PointerExpr.Text(), Type: spec.DefineStruct{RawName: raw}}
|
return spec.PointerType{RawName: v.PointerExpr.Text(), Type: spec.DefineStruct{RawName: raw}}
|
||||||
}
|
}
|
||||||
|
|
||||||
panic(fmt.Sprintf("unspported type %+v", in))
|
panic(fmt.Sprintf("unsupported type %+v", in))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p parser) stringExprs(docs []ast.Expr) []string {
|
func (p parser) stringExprs(docs []ast.Expr) []string {
|
||||||
|
|||||||
@@ -8,68 +8,71 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func getBoolFromKVOrDefault(properties map[string]string, key string, def bool) bool {
|
func getBoolFromKVOrDefault(properties map[string]string, key string, def bool) bool {
|
||||||
if len(properties) == 0 {
|
return getOrDefault(properties, key, def, func(str string, def bool) bool {
|
||||||
return def
|
res, err := strconv.ParseBool(str)
|
||||||
}
|
if err != nil {
|
||||||
md := metadata.New(properties)
|
return def
|
||||||
val := md.Get(key)
|
}
|
||||||
if len(val) == 0 {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
str := util.Unquote(val[0])
|
|
||||||
if len(str) == 0 {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
res, _ := strconv.ParseBool(str)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func getStringFromKVOrDefault(properties map[string]string, key string, def string) string {
|
return res
|
||||||
if len(properties) == 0 {
|
})
|
||||||
return def
|
|
||||||
}
|
|
||||||
md := metadata.New(properties)
|
|
||||||
val := md.Get(key)
|
|
||||||
if len(val) == 0 {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
str := util.Unquote(val[0])
|
|
||||||
if len(str) == 0 {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
return str
|
|
||||||
}
|
|
||||||
|
|
||||||
func getListFromInfoOrDefault(properties map[string]string, key string, def []string) []string {
|
|
||||||
if len(properties) == 0 {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
md := metadata.New(properties)
|
|
||||||
val := md.Get(key)
|
|
||||||
if len(val) == 0 {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
|
|
||||||
str := util.Unquote(val[0])
|
|
||||||
if len(str) == 0 {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
resp := util.FieldsAndTrimSpace(str, commaRune)
|
|
||||||
if len(resp) == 0 {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
return resp
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getFirstUsableString(def ...string) string {
|
func getFirstUsableString(def ...string) string {
|
||||||
if len(def) == 0 {
|
if len(def) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, val := range def {
|
for _, val := range def {
|
||||||
str := util.Unquote(val)
|
// Try to unquote if it's a quoted string
|
||||||
if len(str) != 0 {
|
if str, err := strconv.Unquote(val); err == nil && len(str) != 0 {
|
||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Otherwise, use the value as-is if it's not empty
|
||||||
|
if len(val) != 0 {
|
||||||
|
return val
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getListFromInfoOrDefault(properties map[string]string, key string, def []string) []string {
|
||||||
|
return getOrDefault(properties, key, def, func(str string, def []string) []string {
|
||||||
|
resp := util.FieldsAndTrimSpace(str, commaRune)
|
||||||
|
if len(resp) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOrDefault abstracts the common logic for fetching, unquoting, and defaulting.
|
||||||
|
func getOrDefault[T any](properties map[string]string, key string, def T, convert func(string, T) T) T {
|
||||||
|
if len(properties) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
md := metadata.New(properties)
|
||||||
|
val := md.Get(key)
|
||||||
|
if len(val) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
str := val[0]
|
||||||
|
if unquoted, err := strconv.Unquote(str); err == nil {
|
||||||
|
str = unquoted
|
||||||
|
}
|
||||||
|
if len(str) == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
return convert(str, def)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStringFromKVOrDefault(properties map[string]string, key string, def string) string {
|
||||||
|
return getOrDefault(properties, key, def, func(str string, def string) string {
|
||||||
|
return str
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,6 +21,19 @@ func Test_getBoolFromKVOrDefault(t *testing.T) {
|
|||||||
assert.False(t, getBoolFromKVOrDefault(properties, "empty_value", false))
|
assert.False(t, getBoolFromKVOrDefault(properties, "empty_value", false))
|
||||||
assert.False(t, getBoolFromKVOrDefault(nil, "nil", false))
|
assert.False(t, getBoolFromKVOrDefault(nil, "nil", false))
|
||||||
assert.False(t, getBoolFromKVOrDefault(map[string]string{}, "empty", false))
|
assert.False(t, getBoolFromKVOrDefault(map[string]string{}, "empty", false))
|
||||||
|
|
||||||
|
// Test with unquoted values (as stored by RawText())
|
||||||
|
unquotedProperties := map[string]string{
|
||||||
|
"enabled": "true",
|
||||||
|
"disabled": "false",
|
||||||
|
"invalid": "notabool",
|
||||||
|
"empty_value": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, getBoolFromKVOrDefault(unquotedProperties, "enabled", false))
|
||||||
|
assert.False(t, getBoolFromKVOrDefault(unquotedProperties, "disabled", true))
|
||||||
|
assert.False(t, getBoolFromKVOrDefault(unquotedProperties, "invalid", false))
|
||||||
|
assert.False(t, getBoolFromKVOrDefault(unquotedProperties, "empty_value", false))
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_getStringFromKVOrDefault(t *testing.T) {
|
func Test_getStringFromKVOrDefault(t *testing.T) {
|
||||||
@@ -34,6 +47,17 @@ func Test_getStringFromKVOrDefault(t *testing.T) {
|
|||||||
assert.Equal(t, "default", getStringFromKVOrDefault(properties, "missing", "default"))
|
assert.Equal(t, "default", getStringFromKVOrDefault(properties, "missing", "default"))
|
||||||
assert.Equal(t, "default", getStringFromKVOrDefault(nil, "nil", "default"))
|
assert.Equal(t, "default", getStringFromKVOrDefault(nil, "nil", "default"))
|
||||||
assert.Equal(t, "default", getStringFromKVOrDefault(map[string]string{}, "empty", "default"))
|
assert.Equal(t, "default", getStringFromKVOrDefault(map[string]string{}, "empty", "default"))
|
||||||
|
|
||||||
|
// Test with unquoted values (as stored by RawText())
|
||||||
|
unquotedProperties := map[string]string{
|
||||||
|
"name": "example",
|
||||||
|
"title": "Demo API",
|
||||||
|
"empty": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "example", getStringFromKVOrDefault(unquotedProperties, "name", "default"))
|
||||||
|
assert.Equal(t, "Demo API", getStringFromKVOrDefault(unquotedProperties, "title", "default"))
|
||||||
|
assert.Equal(t, "default", getStringFromKVOrDefault(unquotedProperties, "empty", "default"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_getListFromInfoOrDefault(t *testing.T) {
|
func Test_getListFromInfoOrDefault(t *testing.T) {
|
||||||
@@ -50,4 +74,123 @@ func Test_getListFromInfoOrDefault(t *testing.T) {
|
|||||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(map[string]string{
|
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(map[string]string{
|
||||||
"foo": ",,",
|
"foo": ",,",
|
||||||
}, "foo", []string{"default"}))
|
}, "foo", []string{"default"}))
|
||||||
|
|
||||||
|
// Test with unquoted values (as stored by RawText())
|
||||||
|
unquotedProperties := map[string]string{
|
||||||
|
"list": "a, b, c",
|
||||||
|
"schemes": "http,https",
|
||||||
|
"tags": "query",
|
||||||
|
"empty": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: FieldsAndTrimSpace doesn't actually trim the spaces from returned values
|
||||||
|
assert.Equal(t, []string{"a", " b", " c"}, getListFromInfoOrDefault(unquotedProperties, "list", []string{"default"}))
|
||||||
|
assert.Equal(t, []string{"http", "https"}, getListFromInfoOrDefault(unquotedProperties, "schemes", []string{"default"}))
|
||||||
|
assert.Equal(t, []string{"query"}, getListFromInfoOrDefault(unquotedProperties, "tags", []string{"default"}))
|
||||||
|
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(unquotedProperties, "empty", []string{"default"}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_getFirstUsableString(t *testing.T) {
|
||||||
|
t.Run("empty input", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString()
|
||||||
|
assert.Equal(t, "", result, "should return empty string for no arguments")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("single plain string", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("Check server health status.")
|
||||||
|
assert.Equal(t, "Check server health status.", result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("single quoted string", func(t *testing.T) {
|
||||||
|
// This is how Go would represent a quoted string literal
|
||||||
|
result := getFirstUsableString(`"Check server health status."`)
|
||||||
|
assert.Equal(t, "Check server health status.", result, "should unquote quoted strings")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple plain strings", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("", "second", "third")
|
||||||
|
assert.Equal(t, "second", result, "should return first non-empty string")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handler name fallback", func(t *testing.T) {
|
||||||
|
// Simulates the real use case: @doc text, handler name
|
||||||
|
result := getFirstUsableString("", "HealthCheck")
|
||||||
|
assert.Equal(t, "HealthCheck", result, "should fallback to handler name")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("doc text over handler name", func(t *testing.T) {
|
||||||
|
// Simulates the real use case with @doc text
|
||||||
|
result := getFirstUsableString("Check server health status.", "HealthCheck")
|
||||||
|
assert.Equal(t, "Check server health status.", result, "should use doc text over handler name")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty strings before valid", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("", "", "valid")
|
||||||
|
assert.Equal(t, "valid", result, "should skip empty strings")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("all empty strings", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("", "", "")
|
||||||
|
assert.Equal(t, "", result, "should return empty if all are empty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("quoted then plain", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString(`"quoted"`, "plain")
|
||||||
|
assert.Equal(t, "quoted", result, "should unquote first quoted string")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("plain then quoted", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("plain", `"quoted"`)
|
||||||
|
assert.Equal(t, "plain", result, "should use first plain string")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid quoted string", func(t *testing.T) {
|
||||||
|
// String that looks quoted but isn't valid Go syntax
|
||||||
|
result := getFirstUsableString(`"incomplete`, "fallback")
|
||||||
|
assert.Equal(t, `"incomplete`, result, "should use as-is if unquote fails but not empty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("whitespace only", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString(" ", "fallback")
|
||||||
|
assert.Equal(t, " ", result, "should not trim whitespace, return as-is")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("real world API doc scenario", func(t *testing.T) {
|
||||||
|
// This is the actual bug scenario from issue #5229
|
||||||
|
atDocText := "Check server health status."
|
||||||
|
handlerName := "HealthCheck"
|
||||||
|
|
||||||
|
result := getFirstUsableString(atDocText, handlerName)
|
||||||
|
assert.Equal(t, "Check server health status.", result,
|
||||||
|
"should use @doc text for API summary")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("real world with empty doc", func(t *testing.T) {
|
||||||
|
// When @doc is empty, should fall back to handler name
|
||||||
|
atDocText := ""
|
||||||
|
handlerName := "HealthCheck"
|
||||||
|
|
||||||
|
result := getFirstUsableString(atDocText, handlerName)
|
||||||
|
assert.Equal(t, "HealthCheck", result,
|
||||||
|
"should fallback to handler name when @doc is empty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("complex summary with special characters", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("Get user by ID: /users/{id}")
|
||||||
|
assert.Equal(t, "Get user by ID: /users/{id}", result,
|
||||||
|
"should handle special characters in plain strings")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiline string", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("Line 1\nLine 2")
|
||||||
|
assert.Equal(t, "Line 1\nLine 2", result,
|
||||||
|
"should handle multiline strings")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unicode characters", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("健康检查", "HealthCheck")
|
||||||
|
assert.Equal(t, "健康检查", result,
|
||||||
|
"should handle unicode characters")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user