mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-21 21:58:18 +08:00
Compare commits
28 Commits
copilot/fi
...
v1.9.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
75941aedd4 | ||
|
|
c7065171d7 | ||
|
|
052de3b552 | ||
|
|
866613af8c | ||
|
|
3d4f6a5e16 | ||
|
|
d1d47d02d5 | ||
|
|
d6c876860b | ||
|
|
98423ca948 | ||
|
|
4e52d77ad8 | ||
|
|
1fc2cfb859 | ||
|
|
942cdae41d | ||
|
|
e9c3607bc6 | ||
|
|
d1603e9166 | ||
|
|
e30317e9c4 | ||
|
|
568f9ce007 | ||
|
|
dcb309065a | ||
|
|
bf8e17a686 | ||
|
|
b2ebbfce62 | ||
|
|
2b10a6a223 | ||
|
|
80c320b46e | ||
|
|
bea9d150a1 | ||
|
|
3f756a2cbf | ||
|
|
bbe5bbb0c0 | ||
|
|
5ad2278a69 | ||
|
|
77763fe748 | ||
|
|
538c4fb5c7 | ||
|
|
315fb2fe0a | ||
|
|
e382887eb8 |
6
.github/workflows/codeql-analysis.yml
vendored
6
.github/workflows/codeql-analysis.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
|||||||
|
|
||||||
# Initializes the CodeQL tools for scanning.
|
# Initializes the CodeQL tools for scanning.
|
||||||
- name: Initialize CodeQL
|
- name: Initialize CodeQL
|
||||||
uses: github/codeql-action/init@v3
|
uses: github/codeql-action/init@v4
|
||||||
with:
|
with:
|
||||||
languages: ${{ matrix.language }}
|
languages: ${{ matrix.language }}
|
||||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||||
@@ -50,7 +50,7 @@ jobs:
|
|||||||
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
|
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
|
||||||
# If this step fails, then you should remove it and run the build manually (see below)
|
# If this step fails, then you should remove it and run the build manually (see below)
|
||||||
- name: Autobuild
|
- name: Autobuild
|
||||||
uses: github/codeql-action/autobuild@v3
|
uses: github/codeql-action/autobuild@v4
|
||||||
|
|
||||||
# ℹ️ Command-line programs to run using the OS shell.
|
# ℹ️ Command-line programs to run using the OS shell.
|
||||||
# 📚 https://git.io/JvXDl
|
# 📚 https://git.io/JvXDl
|
||||||
@@ -64,4 +64,4 @@ jobs:
|
|||||||
# make release
|
# make release
|
||||||
|
|
||||||
- name: Perform CodeQL Analysis
|
- name: Perform CodeQL Analysis
|
||||||
uses: github/codeql-action/analyze@v3
|
uses: github/codeql-action/analyze@v4
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ type (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// New create a Filter, store is the backed redis, key is the key for the bloom filter,
|
// New creates a Filter, store is the backed redis, key is the key for the bloom filter,
|
||||||
// bits is how many bits will be used, maps is how many hashes for each addition.
|
// bits is how many bits will be used, maps is how many hashes for each addition.
|
||||||
// best practices:
|
// best practices:
|
||||||
// elements - means how many actual elements
|
// elements - means how many actual elements
|
||||||
|
|||||||
@@ -81,6 +81,10 @@ func (c *Cache) Del(key string) {
|
|||||||
delete(c.data, key)
|
delete(c.data, key)
|
||||||
c.lruCache.remove(key)
|
c.lruCache.remove(key)
|
||||||
c.lock.Unlock()
|
c.lock.Unlock()
|
||||||
|
|
||||||
|
// RemoveTimer is called outside the lock to avoid performance impact from this
|
||||||
|
// potentially time-consuming operation. Data integrity is maintained by lruCache,
|
||||||
|
// which will eventually evict any remaining entries when capacity is exceeded.
|
||||||
c.timingWheel.RemoveTimer(key)
|
c.timingWheel.RemoveTimer(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ func (s Stream) Count() (count int) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Distinct removes the duplicated items base on the given KeyFunc.
|
// Distinct removes the duplicated items based on the given KeyFunc.
|
||||||
func (s Stream) Distinct(fn KeyFunc) Stream {
|
func (s Stream) Distinct(fn KeyFunc) Stream {
|
||||||
source := make(chan any)
|
source := make(chan any)
|
||||||
|
|
||||||
@@ -459,7 +459,7 @@ func (s Stream) Tail(n int64) Stream {
|
|||||||
return Range(source)
|
return Range(source)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Walk lets the callers handle each item, the caller may write zero, one or more items base on the given item.
|
// Walk lets the callers handle each item, the caller may write zero, one or more items based on the given item.
|
||||||
func (s Stream) Walk(fn WalkFunc, opts ...Option) Stream {
|
func (s Stream) Walk(fn WalkFunc, opts ...Option) Stream {
|
||||||
option := buildOptions(opts...)
|
option := buildOptions(opts...)
|
||||||
if option.unlimitedWorkers {
|
if option.unlimitedWorkers {
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ func (h *ConsistentHash) AddWithWeight(node any, weight int) {
|
|||||||
h.AddWithReplicas(node, replicas)
|
h.AddWithReplicas(node, replicas)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get returns the corresponding node from h base on the given v.
|
// Get returns the corresponding node from h based on the given v.
|
||||||
func (h *ConsistentHash) Get(v any) (any, bool) {
|
func (h *ConsistentHash) Get(v any) (any, bool) {
|
||||||
h.lock.RLock()
|
h.lock.RLock()
|
||||||
defer h.lock.RUnlock()
|
defer h.lock.RUnlock()
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ type (
|
|||||||
gzip bool
|
gzip bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// SizeLimitRotateRule a rotation rule that make the log file rotated base on size
|
// SizeLimitRotateRule a rotation rule that makes the log file rotated based on size
|
||||||
SizeLimitRotateRule struct {
|
SizeLimitRotateRule struct {
|
||||||
DailyRotateRule
|
DailyRotateRule
|
||||||
maxSize int64
|
maxSize int64
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// An Unstable is used to generate random value around the mean value base on given deviation.
|
// An Unstable is used to generate random value around the mean value based on given deviation.
|
||||||
type Unstable struct {
|
type Unstable struct {
|
||||||
deviation float64
|
deviation float64
|
||||||
r *rand.Rand
|
r *rand.Rand
|
||||||
|
|||||||
@@ -259,12 +259,34 @@ func (s *Redis) BitPosCtx(ctx context.Context, key string, bit, start, end int64
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Blpop uses passed in redis connection to execute blocking queries.
|
// Blpop uses passed in redis connection to execute blocking queries.
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode to avoid
|
||||||
|
// exhausting the connection pool. Blocking commands hold connections for extended periods and should
|
||||||
|
// not share the regular connection pool.
|
||||||
|
//
|
||||||
|
// Example usage:
|
||||||
|
//
|
||||||
|
// node, err := redis.CreateBlockingNode(rds)
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
// defer node.Close()
|
||||||
|
//
|
||||||
|
// value, err := rds.Blpop(node, "mylist")
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
//
|
||||||
// Doesn't benefit from pooling redis connections of blocking queries
|
// Doesn't benefit from pooling redis connections of blocking queries
|
||||||
func (s *Redis) Blpop(node RedisNode, key string) (string, error) {
|
func (s *Redis) Blpop(node RedisNode, key string) (string, error) {
|
||||||
return s.BlpopCtx(context.Background(), node, key)
|
return s.BlpopCtx(context.Background(), node, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BlpopCtx uses passed in redis connection to execute blocking queries.
|
// BlpopCtx uses passed in redis connection to execute blocking queries.
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
|
||||||
|
// See Blpop for usage examples.
|
||||||
|
//
|
||||||
// Doesn't benefit from pooling redis connections of blocking queries
|
// Doesn't benefit from pooling redis connections of blocking queries
|
||||||
func (s *Redis) BlpopCtx(ctx context.Context, node RedisNode, key string) (string, error) {
|
func (s *Redis) BlpopCtx(ctx context.Context, node RedisNode, key string) (string, error) {
|
||||||
return s.BlpopWithTimeoutCtx(ctx, node, blockingQueryTimeout, key)
|
return s.BlpopWithTimeoutCtx(ctx, node, blockingQueryTimeout, key)
|
||||||
@@ -272,12 +294,18 @@ func (s *Redis) BlpopCtx(ctx context.Context, node RedisNode, key string) (strin
|
|||||||
|
|
||||||
// BlpopEx uses passed in redis connection to execute blpop command.
|
// BlpopEx uses passed in redis connection to execute blpop command.
|
||||||
// The difference against Blpop is that this method returns a bool to indicate success.
|
// The difference against Blpop is that this method returns a bool to indicate success.
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
|
||||||
|
// See Blpop for usage examples.
|
||||||
func (s *Redis) BlpopEx(node RedisNode, key string) (string, bool, error) {
|
func (s *Redis) BlpopEx(node RedisNode, key string) (string, bool, error) {
|
||||||
return s.BlpopExCtx(context.Background(), node, key)
|
return s.BlpopExCtx(context.Background(), node, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BlpopExCtx uses passed in redis connection to execute blpop command.
|
// BlpopExCtx uses passed in redis connection to execute blpop command.
|
||||||
// The difference against Blpop is that this method returns a bool to indicate success.
|
// The difference against Blpop is that this method returns a bool to indicate success.
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
|
||||||
|
// See Blpop for usage examples.
|
||||||
func (s *Redis) BlpopExCtx(ctx context.Context, node RedisNode, key string) (string, bool, error) {
|
func (s *Redis) BlpopExCtx(ctx context.Context, node RedisNode, key string) (string, bool, error) {
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return "", false, ErrNilNode
|
return "", false, ErrNilNode
|
||||||
@@ -297,12 +325,18 @@ func (s *Redis) BlpopExCtx(ctx context.Context, node RedisNode, key string) (str
|
|||||||
|
|
||||||
// BlpopWithTimeout uses passed in redis connection to execute blpop command.
|
// BlpopWithTimeout uses passed in redis connection to execute blpop command.
|
||||||
// Control blocking query timeout
|
// Control blocking query timeout
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
|
||||||
|
// See Blpop for usage examples.
|
||||||
func (s *Redis) BlpopWithTimeout(node RedisNode, timeout time.Duration, key string) (string, error) {
|
func (s *Redis) BlpopWithTimeout(node RedisNode, timeout time.Duration, key string) (string, error) {
|
||||||
return s.BlpopWithTimeoutCtx(context.Background(), node, timeout, key)
|
return s.BlpopWithTimeoutCtx(context.Background(), node, timeout, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BlpopWithTimeoutCtx uses passed in redis connection to execute blpop command.
|
// BlpopWithTimeoutCtx uses passed in redis connection to execute blpop command.
|
||||||
// Control blocking query timeout
|
// Control blocking query timeout
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
|
||||||
|
// See Blpop for usage examples.
|
||||||
func (s *Redis) BlpopWithTimeoutCtx(ctx context.Context, node RedisNode, timeout time.Duration,
|
func (s *Redis) BlpopWithTimeoutCtx(ctx context.Context, node RedisNode, timeout time.Duration,
|
||||||
key string) (string, error) {
|
key string) (string, error) {
|
||||||
if node == nil {
|
if node == nil {
|
||||||
@@ -1840,6 +1874,29 @@ func (s *Redis) XInfoStreamCtx(ctx context.Context, stream string) (*red.XInfoSt
|
|||||||
|
|
||||||
// XReadGroup reads messages from Redis streams as part of a consumer group.
|
// XReadGroup reads messages from Redis streams as part of a consumer group.
|
||||||
// It allows for distributed processing of stream messages with automatic message delivery semantics.
|
// It allows for distributed processing of stream messages with automatic message delivery semantics.
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode to avoid
|
||||||
|
// exhausting the connection pool. Blocking commands hold connections for extended periods and should
|
||||||
|
// not share the regular connection pool.
|
||||||
|
//
|
||||||
|
// Example usage:
|
||||||
|
//
|
||||||
|
// node, err := redis.CreateBlockingNode(rds)
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
// defer node.Close()
|
||||||
|
//
|
||||||
|
// streams, err := rds.XReadGroup(
|
||||||
|
// node, // RedisNode created with CreateBlockingNode
|
||||||
|
// "mygroup", // consumer group name
|
||||||
|
// "consumer1", // consumer ID
|
||||||
|
// 10, // max number of messages to read
|
||||||
|
// 5*time.Second, // block duration
|
||||||
|
// false, // noAck flag
|
||||||
|
// "mystream", // stream name
|
||||||
|
// )
|
||||||
|
//
|
||||||
// Doesn't benefit from pooling redis connections of blocking queries.
|
// Doesn't benefit from pooling redis connections of blocking queries.
|
||||||
func (s *Redis) XReadGroup(node RedisNode, group string, consumerId string, count int64,
|
func (s *Redis) XReadGroup(node RedisNode, group string, consumerId string, count int64,
|
||||||
block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
|
block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
|
||||||
@@ -1847,6 +1904,10 @@ func (s *Redis) XReadGroup(node RedisNode, group string, consumerId string, coun
|
|||||||
}
|
}
|
||||||
|
|
||||||
// XReadGroupCtx is the context-aware version of XReadGroup.
|
// XReadGroupCtx is the context-aware version of XReadGroup.
|
||||||
|
//
|
||||||
|
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode to avoid
|
||||||
|
// exhausting the connection pool. See XReadGroup for usage examples.
|
||||||
|
//
|
||||||
// Doesn't benefit from pooling redis connections of blocking queries.
|
// Doesn't benefit from pooling redis connections of blocking queries.
|
||||||
func (s *Redis) XReadGroupCtx(ctx context.Context, node RedisNode, group string, consumerId string,
|
func (s *Redis) XReadGroupCtx(ctx context.Context, node RedisNode, group string, consumerId string,
|
||||||
count int64, block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
|
count int64, block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
|
||||||
|
|||||||
@@ -13,7 +13,37 @@ type ClosableNode interface {
|
|||||||
Close()
|
Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateBlockingNode returns a ClosableNode.
|
// CreateBlockingNode creates a dedicated RedisNode for blocking operations.
|
||||||
|
//
|
||||||
|
// Blocking Redis commands (like BLPOP, BRPOP, XREADGROUP with block parameter) hold connections
|
||||||
|
// for extended periods while waiting for data. Using them with the regular Redis connection pool
|
||||||
|
// can exhaust all available connections, causing other operations to fail or timeout.
|
||||||
|
//
|
||||||
|
// CreateBlockingNode creates a separate Redis client with a minimal connection pool (size 1) that
|
||||||
|
// is dedicated to blocking operations. This ensures blocking commands don't interfere with regular
|
||||||
|
// Redis operations.
|
||||||
|
//
|
||||||
|
// Example usage:
|
||||||
|
//
|
||||||
|
// rds := redis.MustNewRedis(redis.RedisConf{
|
||||||
|
// Host: "localhost:6379",
|
||||||
|
// Type: redis.NodeType,
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// // Create a dedicated node for blocking operations
|
||||||
|
// node, err := redis.CreateBlockingNode(rds)
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
// defer node.Close() // Important: close the node when done
|
||||||
|
//
|
||||||
|
// // Use the node for blocking operations
|
||||||
|
// value, err := rds.Blpop(node, "mylist")
|
||||||
|
// if err != nil {
|
||||||
|
// // handle error
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// The returned ClosableNode must be closed when no longer needed to release resources.
|
||||||
func CreateBlockingNode(r *Redis) (ClosableNode, error) {
|
func CreateBlockingNode(r *Redis) (ClosableNode, error) {
|
||||||
timeout := readWriteTimeout + blockingQueryTimeout
|
timeout := readWriteTimeout + blockingQueryTimeout
|
||||||
|
|
||||||
|
|||||||
@@ -70,25 +70,16 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getValueInterface(value reflect.Value) (any, error) {
|
func getValueInterface(value reflect.Value) (any, error) {
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Ptr:
|
|
||||||
if !value.CanInterface() {
|
|
||||||
return nil, ErrNotReadableValue
|
|
||||||
}
|
|
||||||
|
|
||||||
if value.IsNil() {
|
|
||||||
baseValueType := mapping.Deref(value.Type())
|
|
||||||
value.Set(reflect.New(baseValueType))
|
|
||||||
}
|
|
||||||
|
|
||||||
return value.Interface(), nil
|
|
||||||
default:
|
|
||||||
if !value.CanAddr() || !value.Addr().CanInterface() {
|
if !value.CanAddr() || !value.Addr().CanInterface() {
|
||||||
return nil, ErrNotReadableValue
|
return nil, ErrNotReadableValue
|
||||||
}
|
}
|
||||||
|
|
||||||
return value.Addr().Interface(), nil
|
if value.Kind() == reflect.Pointer && value.IsNil() {
|
||||||
|
baseValueType := mapping.Deref(value.Type())
|
||||||
|
value.Set(reflect.New(baseValueType))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return value.Addr().Interface(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isScanFailed(err error) bool {
|
func isScanFailed(err error) bool {
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -1575,6 +1577,782 @@ func TestAnonymousStructPrError(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsZeroValueStructPtr(t *testing.T) {
|
||||||
|
secondNamePtr := "second_ptr"
|
||||||
|
secondAgePtr := int64(30)
|
||||||
|
thirdNamePtr := "third_ptr"
|
||||||
|
thirdAgePtr := int64(0)
|
||||||
|
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
NamePtr *string
|
||||||
|
Age int64
|
||||||
|
AgePtr *int64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "first",
|
||||||
|
NamePtr: nil,
|
||||||
|
Age: 2,
|
||||||
|
AgePtr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "second",
|
||||||
|
NamePtr: &secondNamePtr,
|
||||||
|
Age: 3,
|
||||||
|
AgePtr: &secondAgePtr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "",
|
||||||
|
NamePtr: &thirdNamePtr,
|
||||||
|
Age: 0,
|
||||||
|
AgePtr: &thirdAgePtr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
Age int64 `db:"age"`
|
||||||
|
AgePtr *int64 `db:"age_ptr"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
NamePtr *string `db:"name_ptr"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "name_ptr", "age", "age_ptr"}).
|
||||||
|
AddRow("first", nil, 2, nil).
|
||||||
|
AddRow("second", "second_ptr", 3, 30).
|
||||||
|
AddRow("", "third_ptr", 0, 0)
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").
|
||||||
|
WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, name_ptr, age, age_ptr from users where user=?", "anyone"))
|
||||||
|
|
||||||
|
assert.Equal(t, 3, len(value), "应该返回3行数据")
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
|
||||||
|
assert.Equal(t, each.Name, value[i].Name)
|
||||||
|
assert.Equal(t, each.Age, value[i].Age)
|
||||||
|
|
||||||
|
if each.NamePtr == nil {
|
||||||
|
assert.Nil(t, value[i].NamePtr)
|
||||||
|
} else {
|
||||||
|
assert.NotNil(t, value[i].NamePtr)
|
||||||
|
assert.Equal(t, *each.NamePtr, *value[i].NamePtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if each.AgePtr == nil {
|
||||||
|
assert.Nil(t, value[i].AgePtr)
|
||||||
|
} else {
|
||||||
|
assert.NotNil(t, value[i].AgePtr)
|
||||||
|
assert.Equal(t, *each.AgePtr, *value[i].AgePtr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsAllNullStructPtrFields(t *testing.T) {
|
||||||
|
expect := []struct {
|
||||||
|
NamePtr *string
|
||||||
|
AgePtr *int64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
NamePtr: nil,
|
||||||
|
AgePtr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NamePtr: stringPtr("second"),
|
||||||
|
AgePtr: int64Ptr(30),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NamePtr: nil,
|
||||||
|
AgePtr: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
AgePtr *int64 `db:"age_ptr"`
|
||||||
|
NamePtr *string `db:"name_ptr"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"name_ptr", "age_ptr"}).
|
||||||
|
AddRow(nil, nil).
|
||||||
|
AddRow("second", 30).
|
||||||
|
AddRow(nil, nil)
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from users where user=?").
|
||||||
|
WithArgs("anyone").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name_ptr, age_ptr from users where user=?", "anyone"))
|
||||||
|
|
||||||
|
assert.Equal(t, 3, len(value))
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
if each.NamePtr == nil {
|
||||||
|
assert.Nil(t, value[i].NamePtr)
|
||||||
|
} else {
|
||||||
|
assert.NotNil(t, value[i].NamePtr)
|
||||||
|
assert.Equal(t, *each.NamePtr, *value[i].NamePtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if each.AgePtr == nil {
|
||||||
|
assert.Nil(t, value[i].AgePtr)
|
||||||
|
} else {
|
||||||
|
assert.NotNil(t, value[i].AgePtr)
|
||||||
|
assert.Equal(t, *each.AgePtr, *value[i].AgePtr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsWithSqlNullTypes(t *testing.T) {
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
NullName sql.NullString
|
||||||
|
Age int64
|
||||||
|
NullAge sql.NullInt64
|
||||||
|
Score float64
|
||||||
|
NullScore sql.NullFloat64
|
||||||
|
Active bool
|
||||||
|
NullActive sql.NullBool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "first",
|
||||||
|
NullName: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Age: 20,
|
||||||
|
NullAge: sql.NullInt64{
|
||||||
|
Int64: 0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Score: 85.5,
|
||||||
|
NullScore: sql.NullFloat64{
|
||||||
|
Float64: 0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Active: true,
|
||||||
|
NullActive: sql.NullBool{
|
||||||
|
Bool: false,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "second",
|
||||||
|
NullName: sql.NullString{
|
||||||
|
String: "not_null_name",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
Age: 25,
|
||||||
|
NullAge: sql.NullInt64{
|
||||||
|
Int64: 30,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
Score: 90.0,
|
||||||
|
NullScore: sql.NullFloat64{
|
||||||
|
Float64: 95.5,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
Active: false,
|
||||||
|
NullActive: sql.NullBool{
|
||||||
|
Bool: true,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "third",
|
||||||
|
NullName: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Age: 0,
|
||||||
|
NullAge: sql.NullInt64{
|
||||||
|
Int64: 0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Score: 0,
|
||||||
|
NullScore: sql.NullFloat64{
|
||||||
|
Float64: 0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Active: false,
|
||||||
|
NullActive: sql.NullBool{
|
||||||
|
Bool: false,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
NullName sql.NullString `db:"null_name"`
|
||||||
|
Age int64 `db:"age"`
|
||||||
|
NullAge sql.NullInt64 `db:"null_age"`
|
||||||
|
Score float64 `db:"score"`
|
||||||
|
NullScore sql.NullFloat64 `db:"null_score"`
|
||||||
|
Active bool `db:"active"`
|
||||||
|
NullActive sql.NullBool `db:"null_active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{
|
||||||
|
"name", "null_name", "age", "null_age", "score", "null_score", "active", "null_active",
|
||||||
|
}).
|
||||||
|
AddRow("first", nil, 20, nil, 85.5, nil, true, nil).
|
||||||
|
AddRow("second", "not_null_name", 25, 30, 90.0, 95.5, false, true).
|
||||||
|
AddRow("third", nil, 0, nil, 0, nil, false, nil)
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from users where type=?").
|
||||||
|
WithArgs("test").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, null_name, age, null_age, score, null_score, active, null_active from users where type=?", "test"))
|
||||||
|
|
||||||
|
assert.Equal(t, 3, len(value))
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
assert.Equal(t, each.Name, value[i].Name)
|
||||||
|
assert.Equal(t, each.Age, value[i].Age)
|
||||||
|
assert.Equal(t, each.Score, value[i].Score)
|
||||||
|
assert.Equal(t, each.Active, value[i].Active)
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullName.Valid, value[i].NullName.Valid)
|
||||||
|
if each.NullName.Valid {
|
||||||
|
assert.Equal(t, each.NullName.String, value[i].NullName.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullAge.Valid, value[i].NullAge.Valid)
|
||||||
|
if each.NullAge.Valid {
|
||||||
|
assert.Equal(t, each.NullAge.Int64, value[i].NullAge.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullScore.Valid, value[i].NullScore.Valid)
|
||||||
|
if each.NullScore.Valid {
|
||||||
|
assert.Equal(t, each.NullScore.Float64, value[i].NullScore.Float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullActive.Valid, value[i].NullActive.Valid)
|
||||||
|
if each.NullActive.Valid {
|
||||||
|
assert.Equal(t, each.NullActive.Bool, value[i].NullActive.Bool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsSqlNullWithMixedData(t *testing.T) {
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
NullName sql.NullString
|
||||||
|
Age int64
|
||||||
|
NullAge sql.NullInt64
|
||||||
|
IsStudent bool
|
||||||
|
NullActive sql.NullBool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "student1",
|
||||||
|
NullName: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
Age: 18,
|
||||||
|
NullAge: sql.NullInt64{
|
||||||
|
Int64: 0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
IsStudent: true,
|
||||||
|
NullActive: sql.NullBool{
|
||||||
|
Bool: false,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "student2",
|
||||||
|
NullName: sql.NullString{
|
||||||
|
String: "has_nickname",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
Age: 20,
|
||||||
|
NullAge: sql.NullInt64{
|
||||||
|
Int64: 22,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
IsStudent: false,
|
||||||
|
NullActive: sql.NullBool{
|
||||||
|
Bool: true,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
NullName sql.NullString `db:"null_name"`
|
||||||
|
Age int64 `db:"age"`
|
||||||
|
NullAge sql.NullInt64 `db:"null_age"`
|
||||||
|
IsStudent bool `db:"is_student"`
|
||||||
|
NullActive sql.NullBool `db:"null_active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "null_name", "age", "null_age", "is_student", "null_active"}).
|
||||||
|
AddRow("student1", nil, 18, nil, true, nil).
|
||||||
|
AddRow("student2", "has_nickname", 20, 22, false, true)
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from students where class=?").
|
||||||
|
WithArgs("A").WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, null_name, age, null_age, is_student, null_active from students where class=?", "A"))
|
||||||
|
|
||||||
|
assert.Equal(t, 2, len(value))
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
assert.Equal(t, each.Name, value[i].Name)
|
||||||
|
assert.Equal(t, each.Age, value[i].Age)
|
||||||
|
assert.Equal(t, each.IsStudent, value[i].IsStudent)
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullName.Valid, value[i].NullName.Valid)
|
||||||
|
if each.NullName.Valid {
|
||||||
|
assert.Equal(t, each.NullName.String, value[i].NullName.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullAge.Valid, value[i].NullAge.Valid)
|
||||||
|
if each.NullAge.Valid {
|
||||||
|
assert.Equal(t, each.NullAge.Int64, value[i].NullAge.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullActive.Valid, value[i].NullActive.Valid)
|
||||||
|
if each.NullActive.Valid {
|
||||||
|
assert.Equal(t, each.NullActive.Bool, value[i].NullActive.Bool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsSqlNullTime(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
futureTime := now.AddDate(1, 0, 0)
|
||||||
|
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
BirthDate sql.NullTime
|
||||||
|
LastLogin sql.NullTime
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "user1",
|
||||||
|
BirthDate: sql.NullTime{
|
||||||
|
Time: time.Time{},
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
LastLogin: sql.NullTime{
|
||||||
|
Time: now,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "user2",
|
||||||
|
BirthDate: sql.NullTime{
|
||||||
|
Time: futureTime,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
LastLogin: sql.NullTime{
|
||||||
|
Time: time.Time{},
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
BirthDate sql.NullTime `db:"birth_date"`
|
||||||
|
LastLogin sql.NullTime `db:"last_login"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "birth_date", "last_login"}).
|
||||||
|
AddRow("user1", nil, now).
|
||||||
|
AddRow("user2", futureTime, nil)
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from users").
|
||||||
|
WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, birth_date, last_login from users"))
|
||||||
|
|
||||||
|
assert.Equal(t, 2, len(value))
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
assert.Equal(t, each.Name, value[i].Name)
|
||||||
|
|
||||||
|
assert.Equal(t, each.BirthDate.Valid, value[i].BirthDate.Valid)
|
||||||
|
if each.BirthDate.Valid {
|
||||||
|
assert.WithinDuration(t, each.BirthDate.Time, value[i].BirthDate.Time, time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.LastLogin.Valid, value[i].LastLogin.Valid)
|
||||||
|
if each.LastLogin.Valid {
|
||||||
|
assert.WithinDuration(t, each.LastLogin.Time, value[i].LastLogin.Time, time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsSqlNullWithEmptyValues(t *testing.T) {
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
NullString sql.NullString
|
||||||
|
NullInt sql.NullInt64
|
||||||
|
NullFloat sql.NullFloat64
|
||||||
|
NullBool sql.NullBool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "empty_values",
|
||||||
|
NullString: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullInt: sql.NullInt64{
|
||||||
|
Int64: 0,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullFloat: sql.NullFloat64{
|
||||||
|
Float64: 0.0,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullBool: sql.NullBool{
|
||||||
|
Bool: false,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "null_values",
|
||||||
|
NullString: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
NullInt: sql.NullInt64{
|
||||||
|
Int64: 0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
NullFloat: sql.NullFloat64{
|
||||||
|
Float64: 0.0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
NullBool: sql.NullBool{
|
||||||
|
Bool: false,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "mixed_values",
|
||||||
|
NullString: sql.NullString{
|
||||||
|
String: "actual_value",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullInt: sql.NullInt64{
|
||||||
|
Int64: 0,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullFloat: sql.NullFloat64{
|
||||||
|
Float64: 0.0,
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
NullBool: sql.NullBool{
|
||||||
|
Bool: true,
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
NullString sql.NullString `db:"null_string"`
|
||||||
|
NullInt sql.NullInt64 `db:"null_int"`
|
||||||
|
NullFloat sql.NullFloat64 `db:"null_float"`
|
||||||
|
NullBool sql.NullBool `db:"null_bool"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "null_string", "null_int", "null_float", "null_bool"}).
|
||||||
|
AddRow("empty_values", "", 0, 0.0, false).
|
||||||
|
AddRow("null_values", nil, nil, nil, nil).
|
||||||
|
AddRow("mixed_values", "actual_value", 0, nil, true)
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from test_table").
|
||||||
|
WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, null_string, null_int, null_float, null_bool from test_table"))
|
||||||
|
|
||||||
|
assert.Equal(t, 3, len(value))
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
|
||||||
|
assert.Equal(t, each.Name, value[i].Name)
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullString.Valid, value[i].NullString.Valid)
|
||||||
|
if each.NullString.Valid {
|
||||||
|
assert.Equal(t, each.NullString.String, value[i].NullString.String)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, "", value[i].NullString.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullInt.Valid, value[i].NullInt.Valid)
|
||||||
|
if each.NullInt.Valid {
|
||||||
|
assert.Equal(t, each.NullInt.Int64, value[i].NullInt.Int64)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, int64(0), value[i].NullInt.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullFloat.Valid, value[i].NullFloat.Valid)
|
||||||
|
if each.NullFloat.Valid {
|
||||||
|
assert.Equal(t, each.NullFloat.Float64, value[i].NullFloat.Float64)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, 0.0, value[i].NullFloat.Float64)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, each.NullBool.Valid, value[i].NullBool.Valid)
|
||||||
|
if each.NullBool.Valid {
|
||||||
|
assert.Equal(t, each.NullBool.Bool, value[i].NullBool.Bool)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, false, value[i].NullBool.Bool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalRowsSqlNullStringEmptyVsNull(t *testing.T) {
|
||||||
|
expect := []struct {
|
||||||
|
Name string
|
||||||
|
EmptyString sql.NullString
|
||||||
|
NullString sql.NullString
|
||||||
|
NormalString sql.NullString
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "row1",
|
||||||
|
EmptyString: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullString: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
NormalString: sql.NullString{
|
||||||
|
String: "hello",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "row2",
|
||||||
|
EmptyString: sql.NullString{
|
||||||
|
String: " ",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
NullString: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: false,
|
||||||
|
},
|
||||||
|
NormalString: sql.NullString{
|
||||||
|
String: "",
|
||||||
|
Valid: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
EmptyString sql.NullString `db:"empty_string"`
|
||||||
|
NullString sql.NullString `db:"null_string"`
|
||||||
|
NormalString sql.NullString `db:"normal_string"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||||
|
rs := sqlmock.NewRows([]string{"name", "empty_string", "null_string", "normal_string"}).
|
||||||
|
AddRow("row1", "", nil, "hello").
|
||||||
|
AddRow("row2", " ", nil, "")
|
||||||
|
|
||||||
|
mock.ExpectQuery("select (.+) from string_test").
|
||||||
|
WillReturnRows(rs)
|
||||||
|
|
||||||
|
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||||
|
return unmarshalRows(&value, rows, true)
|
||||||
|
}, "select name, empty_string, null_string, normal_string from string_test"))
|
||||||
|
|
||||||
|
assert.Equal(t, 2, len(value))
|
||||||
|
|
||||||
|
for i, each := range expect {
|
||||||
|
assert.True(t, value[i].EmptyString.Valid)
|
||||||
|
assert.Equal(t, each.EmptyString.String, value[i].EmptyString.String)
|
||||||
|
|
||||||
|
assert.False(t, value[i].NullString.Valid)
|
||||||
|
assert.Equal(t, "", value[i].NullString.String)
|
||||||
|
|
||||||
|
assert.Equal(t, each.NormalString.Valid, value[i].NormalString.Valid)
|
||||||
|
if each.NormalString.Valid {
|
||||||
|
assert.Equal(t, each.NormalString.String, value[i].NormalString.String)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetValueInterface(t *testing.T) {
|
||||||
|
t.Run("non_pointer_field", func(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
}
|
||||||
|
s := testStruct{}
|
||||||
|
v := reflect.ValueOf(&s).Elem()
|
||||||
|
|
||||||
|
nameField := v.Field(0)
|
||||||
|
result, err := getValueInterface(nameField)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
|
||||||
|
// Should return pointer to the field
|
||||||
|
ptr, ok := result.(*string)
|
||||||
|
assert.True(t, ok)
|
||||||
|
*ptr = "test"
|
||||||
|
assert.Equal(t, "test", s.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pointer_field_nil", func(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
NamePtr *string
|
||||||
|
AgePtr *int64
|
||||||
|
}
|
||||||
|
s := testStruct{}
|
||||||
|
v := reflect.ValueOf(&s).Elem()
|
||||||
|
|
||||||
|
// Test with nil pointer field
|
||||||
|
namePtrField := v.Field(0)
|
||||||
|
assert.True(t, namePtrField.IsNil(), "initial pointer should be nil")
|
||||||
|
|
||||||
|
result, err := getValueInterface(namePtrField)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
|
||||||
|
// Should have allocated the pointer
|
||||||
|
assert.False(t, namePtrField.IsNil(), "pointer should be allocated after getValueInterface")
|
||||||
|
|
||||||
|
// Should return pointer to pointer field
|
||||||
|
ptrPtr, ok := result.(**string)
|
||||||
|
assert.True(t, ok)
|
||||||
|
testValue := "initialized"
|
||||||
|
*ptrPtr = &testValue
|
||||||
|
assert.NotNil(t, s.NamePtr)
|
||||||
|
assert.Equal(t, "initialized", *s.NamePtr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pointer_field_already_allocated", func(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
NamePtr *string
|
||||||
|
}
|
||||||
|
initial := "existing"
|
||||||
|
s := testStruct{NamePtr: &initial}
|
||||||
|
v := reflect.ValueOf(&s).Elem()
|
||||||
|
|
||||||
|
namePtrField := v.Field(0)
|
||||||
|
assert.False(t, namePtrField.IsNil(), "pointer should not be nil initially")
|
||||||
|
|
||||||
|
result, err := getValueInterface(namePtrField)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
|
||||||
|
// Should return pointer to pointer field
|
||||||
|
ptrPtr, ok := result.(**string)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
// Verify it points to the existing value
|
||||||
|
assert.Equal(t, "existing", **ptrPtr)
|
||||||
|
|
||||||
|
// Modify through the returned pointer
|
||||||
|
newValue := "modified"
|
||||||
|
*ptrPtr = &newValue
|
||||||
|
assert.Equal(t, "modified", *s.NamePtr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pointer_field_zero_value", func(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
IntPtr *int
|
||||||
|
}
|
||||||
|
s := testStruct{}
|
||||||
|
v := reflect.ValueOf(&s).Elem()
|
||||||
|
|
||||||
|
intPtrField := v.Field(0)
|
||||||
|
result, err := getValueInterface(intPtrField)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// After calling getValueInterface, nil pointer should be allocated
|
||||||
|
assert.NotNil(t, s.IntPtr)
|
||||||
|
|
||||||
|
// Set zero value through returned interface
|
||||||
|
ptrPtr, ok := result.(**int)
|
||||||
|
assert.True(t, ok)
|
||||||
|
zero := 0
|
||||||
|
*ptrPtr = &zero
|
||||||
|
assert.Equal(t, 0, *s.IntPtr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not_addressable_value", func(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
s := testStruct{Name: "test"}
|
||||||
|
v := reflect.ValueOf(s) // Non-pointer, not addressable
|
||||||
|
|
||||||
|
nameField := v.Field(0)
|
||||||
|
result, err := getValueInterface(nameField)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, ErrNotReadableValue, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple_pointer_types", func(t *testing.T) {
|
||||||
|
type testStruct struct {
|
||||||
|
StringPtr *string
|
||||||
|
IntPtr *int
|
||||||
|
Int64Ptr *int64
|
||||||
|
FloatPtr *float64
|
||||||
|
BoolPtr *bool
|
||||||
|
}
|
||||||
|
s := testStruct{}
|
||||||
|
v := reflect.ValueOf(&s).Elem()
|
||||||
|
|
||||||
|
// Test each pointer type gets properly initialized
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
field := v.Field(i)
|
||||||
|
assert.True(t, field.IsNil(), "field %d should start as nil", i)
|
||||||
|
|
||||||
|
result, err := getValueInterface(field)
|
||||||
|
assert.NoError(t, err, "field %d should not error", i)
|
||||||
|
assert.NotNil(t, result, "field %d result should not be nil", i)
|
||||||
|
|
||||||
|
// After getValueInterface, pointer should be allocated
|
||||||
|
assert.False(t, field.IsNil(), "field %d should be allocated", i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringPtr(s string) *string {
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
func int64Ptr(i int64) *int64 {
|
||||||
|
return &i
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkIgnore(b *testing.B) {
|
func BenchmarkIgnore(b *testing.B) {
|
||||||
db, mock, err := sqlmock.New()
|
db, mock, err := sqlmock.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/lang"
|
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
"go.opentelemetry.io/otel/exporters/jaeger"
|
"go.opentelemetry.io/otel/exporters/jaeger"
|
||||||
@@ -30,42 +29,36 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
agents = make(map[string]lang.PlaceholderType)
|
once sync.Once
|
||||||
lock sync.Mutex
|
|
||||||
tp *sdktrace.TracerProvider
|
tp *sdktrace.TracerProvider
|
||||||
|
shutdownOnceFn = sync.OnceFunc(func() {
|
||||||
|
if tp != nil {
|
||||||
|
_ = tp.Shutdown(context.Background())
|
||||||
|
}
|
||||||
|
})
|
||||||
)
|
)
|
||||||
|
|
||||||
// StartAgent starts an opentelemetry agent.
|
// StartAgent starts an opentelemetry agent.
|
||||||
|
// It uses sync.Once to ensure the agent is initialized only once,
|
||||||
|
// similar to prometheus.StartAgent and logx.SetUp.
|
||||||
|
// This prevents multiple ServiceConf.SetUp() calls from reinitializing
|
||||||
|
// the global tracer provider when running multiple servers (e.g., REST + RPC)
|
||||||
|
// in the same process.
|
||||||
func StartAgent(c Config) {
|
func StartAgent(c Config) {
|
||||||
if c.Disabled {
|
if c.Disabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lock.Lock()
|
once.Do(func() {
|
||||||
defer lock.Unlock()
|
|
||||||
|
|
||||||
_, ok := agents[c.Endpoint]
|
|
||||||
if ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// if error happens, let later calls run.
|
|
||||||
if err := startAgent(c); err != nil {
|
if err := startAgent(c); err != nil {
|
||||||
return
|
logx.Error(err)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
agents[c.Endpoint] = lang.Placeholder
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// StopAgent shuts down the span processors in the order they were registered.
|
// StopAgent shuts down the span processors in the order they were registered.
|
||||||
func StopAgent() {
|
func StopAgent() {
|
||||||
lock.Lock()
|
shutdownOnceFn()
|
||||||
defer lock.Unlock()
|
|
||||||
|
|
||||||
if tp != nil {
|
|
||||||
_ = tp.Shutdown(context.Background())
|
|
||||||
tp = nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func createExporter(c Config) (sdktrace.SpanExporter, error) {
|
func createExporter(c Config) (sdktrace.SpanExporter, error) {
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
package trace
|
package trace
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStartAgent(t *testing.T) {
|
func TestStartAgent(t *testing.T) {
|
||||||
@@ -89,23 +92,305 @@ func TestStartAgent(t *testing.T) {
|
|||||||
StartAgent(c10)
|
StartAgent(c10)
|
||||||
defer StopAgent()
|
defer StopAgent()
|
||||||
|
|
||||||
lock.Lock()
|
// With sync.Once, only the first non-disabled config (c1) takes effect.
|
||||||
defer lock.Unlock()
|
// Subsequent calls are ignored, which is the desired behavior to prevent
|
||||||
|
// multiple servers (REST + RPC) from reinitializing the global tracer.
|
||||||
// because remotehost cannot be resolved
|
assert.NotNil(t, tp)
|
||||||
assert.Equal(t, 6, len(agents))
|
}
|
||||||
_, ok := agents[""]
|
|
||||||
assert.True(t, ok)
|
func TestCreateExporter_InvalidFilePath(t *testing.T) {
|
||||||
_, ok = agents[endpoint1]
|
logx.Disable()
|
||||||
assert.True(t, ok)
|
|
||||||
_, ok = agents[endpoint2]
|
c := Config{
|
||||||
assert.False(t, ok)
|
Name: "test-invalid-file",
|
||||||
_, ok = agents[endpoint5]
|
Endpoint: "/non-existent-directory/trace.log",
|
||||||
assert.True(t, ok)
|
Batcher: kindFile,
|
||||||
_, ok = agents[endpoint6]
|
}
|
||||||
assert.False(t, ok)
|
|
||||||
_, ok = agents[endpoint71]
|
_, err := createExporter(c)
|
||||||
assert.True(t, ok)
|
assert.Error(t, err)
|
||||||
_, ok = agents[endpoint72]
|
assert.Contains(t, err.Error(), "file exporter endpoint error")
|
||||||
assert.False(t, ok)
|
}
|
||||||
|
|
||||||
|
func TestCreateExporter_UnknownBatcher(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
c := Config{
|
||||||
|
Name: "test-unknown",
|
||||||
|
Endpoint: "localhost:1234",
|
||||||
|
Batcher: "unknown-batcher-type",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := createExporter(c)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown exporter")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateExporter_ValidExporters(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
wantErr bool
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid file exporter",
|
||||||
|
config: Config{
|
||||||
|
Name: "file-test",
|
||||||
|
Endpoint: "/tmp/trace-test.log",
|
||||||
|
Batcher: kindFile,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid file path",
|
||||||
|
config: Config{
|
||||||
|
Name: "file-test-invalid",
|
||||||
|
Endpoint: "/invalid-path/that/does/not/exist/trace.log",
|
||||||
|
Batcher: kindFile,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "file exporter endpoint error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown batcher",
|
||||||
|
config: Config{
|
||||||
|
Name: "unknown-test",
|
||||||
|
Endpoint: "localhost:1234",
|
||||||
|
Batcher: "invalid-batcher",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "unknown exporter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "jaeger http",
|
||||||
|
config: Config{
|
||||||
|
Name: "jaeger-http",
|
||||||
|
Endpoint: "http://localhost:14268/api/traces",
|
||||||
|
Batcher: kindJaeger,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "jaeger udp",
|
||||||
|
config: Config{
|
||||||
|
Name: "jaeger-udp",
|
||||||
|
Endpoint: "udp://localhost:6831",
|
||||||
|
Batcher: kindJaeger,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zipkin",
|
||||||
|
config: Config{
|
||||||
|
Name: "zipkin",
|
||||||
|
Endpoint: "http://localhost:9411/api/v2/spans",
|
||||||
|
Batcher: kindZipkin,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "otlpgrpc",
|
||||||
|
config: Config{
|
||||||
|
Name: "otlpgrpc",
|
||||||
|
Endpoint: "localhost:4317",
|
||||||
|
Batcher: kindOtlpGrpc,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "otlpgrpc with headers",
|
||||||
|
config: Config{
|
||||||
|
Name: "otlpgrpc-headers",
|
||||||
|
Endpoint: "localhost:4317",
|
||||||
|
Batcher: kindOtlpGrpc,
|
||||||
|
OtlpHeaders: map[string]string{
|
||||||
|
"authorization": "Bearer token123",
|
||||||
|
"x-custom-key": "custom-value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "otlphttp",
|
||||||
|
config: Config{
|
||||||
|
Name: "otlphttp",
|
||||||
|
Endpoint: "localhost:4318",
|
||||||
|
Batcher: kindOtlpHttp,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "otlphttp with headers",
|
||||||
|
config: Config{
|
||||||
|
Name: "otlphttp-headers",
|
||||||
|
Endpoint: "localhost:4318",
|
||||||
|
Batcher: kindOtlpHttp,
|
||||||
|
OtlpHeaders: map[string]string{
|
||||||
|
"authorization": "Bearer token456",
|
||||||
|
"x-api-key": "api-key-value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "otlphttp with headers and path",
|
||||||
|
config: Config{
|
||||||
|
Name: "otlphttp-headers-path",
|
||||||
|
Endpoint: "localhost:4318",
|
||||||
|
Batcher: kindOtlpHttp,
|
||||||
|
OtlpHttpPath: "/v1/traces",
|
||||||
|
OtlpHeaders: map[string]string{
|
||||||
|
"authorization": "Bearer token789",
|
||||||
|
"x-custom-trace": "trace-id",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "otlphttp with secure connection",
|
||||||
|
config: Config{
|
||||||
|
Name: "otlphttp-secure",
|
||||||
|
Endpoint: "localhost:4318",
|
||||||
|
Batcher: kindOtlpHttp,
|
||||||
|
OtlpHttpSecure: true,
|
||||||
|
OtlpHeaders: map[string]string{
|
||||||
|
"authorization": "Bearer secure-token",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
exporter, err := createExporter(tt.config)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
if tt.errMsg != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
}
|
||||||
|
assert.Nil(t, exporter)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, exporter)
|
||||||
|
// Clean up the exporter
|
||||||
|
if exporter != nil {
|
||||||
|
_ = exporter.Shutdown(context.Background())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStopAgent(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
// StopAgent should be idempotent and safe to call multiple times
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
StopAgent()
|
||||||
|
StopAgent()
|
||||||
|
StopAgent()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartAgent_WithEndpoint(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty endpoint - no exporter created",
|
||||||
|
config: Config{
|
||||||
|
Name: "test-no-endpoint",
|
||||||
|
Sampler: 1.0,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid endpoint with file exporter",
|
||||||
|
config: Config{
|
||||||
|
Name: "test-with-endpoint",
|
||||||
|
Endpoint: "/tmp/test-trace.log",
|
||||||
|
Batcher: kindFile,
|
||||||
|
Sampler: 1.0,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "endpoint with invalid exporter type",
|
||||||
|
config: Config{
|
||||||
|
Name: "test-invalid-batcher",
|
||||||
|
Endpoint: "localhost:1234",
|
||||||
|
Batcher: "invalid-type",
|
||||||
|
Sampler: 1.0,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "endpoint with invalid file path",
|
||||||
|
config: Config{
|
||||||
|
Name: "test-invalid-path",
|
||||||
|
Endpoint: "/non/existent/path/trace.log",
|
||||||
|
Batcher: kindFile,
|
||||||
|
Sampler: 1.0,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Reset tp for each test
|
||||||
|
originalTp := tp
|
||||||
|
tp = nil
|
||||||
|
defer func() {
|
||||||
|
if tp != nil {
|
||||||
|
_ = tp.Shutdown(context.Background())
|
||||||
|
}
|
||||||
|
tp = originalTp
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := startAgent(tt.config)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, tp, "TracerProvider should be created")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartAgent_ErrorHandler(t *testing.T) {
|
||||||
|
// Setup a tracer provider to test error handler
|
||||||
|
originalTp := tp
|
||||||
|
tp = nil
|
||||||
|
defer func() {
|
||||||
|
if tp != nil {
|
||||||
|
_ = tp.Shutdown(context.Background())
|
||||||
|
}
|
||||||
|
tp = originalTp
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Call startAgent to set up the error handler
|
||||||
|
config := Config{
|
||||||
|
Name: "test-error-handler",
|
||||||
|
Sampler: 1.0,
|
||||||
|
}
|
||||||
|
err := startAgent(config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, tp)
|
||||||
|
|
||||||
|
// Verify the error handler was set and can be called without panicking
|
||||||
|
// We test this by calling otel.Handle which will invoke the registered error handler
|
||||||
|
testErr := errors.New("test otel error")
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
otel.Handle(testErr)
|
||||||
|
}, "Error handler should handle errors without panicking")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,16 +11,40 @@ const (
|
|||||||
metadataPrefix = "gateway-"
|
metadataPrefix = "gateway-"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OpenTelemetry trace propagation headers that need to be forwarded to gRPC metadata.
|
||||||
|
// These headers are used by the W3C Trace Context standard for distributed tracing.
|
||||||
|
var traceHeaders = map[string]bool{
|
||||||
|
"traceparent": true,
|
||||||
|
"tracestate": true,
|
||||||
|
"baggage": true,
|
||||||
|
}
|
||||||
|
|
||||||
// ProcessHeaders builds the headers for the gateway from HTTP headers.
|
// ProcessHeaders builds the headers for the gateway from HTTP headers.
|
||||||
|
// It forwards both custom metadata headers (with Grpc-Metadata- prefix)
|
||||||
|
// and OpenTelemetry trace propagation headers (traceparent, tracestate, baggage)
|
||||||
|
// to ensure distributed tracing works correctly across the gateway.
|
||||||
func ProcessHeaders(header http.Header) []string {
|
func ProcessHeaders(header http.Header) []string {
|
||||||
var headers []string
|
var headers []string
|
||||||
|
|
||||||
for k, v := range header {
|
for k, v := range header {
|
||||||
|
// Forward OpenTelemetry trace propagation headers
|
||||||
|
// These must be lowercase per gRPC metadata conventions
|
||||||
|
if lowerKey := strings.ToLower(k); traceHeaders[lowerKey] {
|
||||||
|
for _, vv := range v {
|
||||||
|
headers = append(headers, lowerKey+":"+vv)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward custom metadata headers with Grpc-Metadata- prefix
|
||||||
if !strings.HasPrefix(k, metadataHeaderPrefix) {
|
if !strings.HasPrefix(k, metadataHeaderPrefix) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
key := fmt.Sprintf("%s%s", metadataPrefix, strings.TrimPrefix(k, metadataHeaderPrefix))
|
// gRPC metadata keys are case-insensitive and stored as lowercase,
|
||||||
|
// so we lowercase the key to match gRPC conventions
|
||||||
|
trimmedKey := strings.TrimPrefix(k, metadataHeaderPrefix)
|
||||||
|
key := strings.ToLower(fmt.Sprintf("%s%s", metadataPrefix, trimmedKey))
|
||||||
for _, vv := range v {
|
for _, vv := range v {
|
||||||
headers = append(headers, key+":"+vv)
|
headers = append(headers, key+":"+vv)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,5 +18,93 @@ func TestBuildHeadersWithValues(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/", http.NoBody)
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||||
req.Header.Add("grpc-metadata-a", "b")
|
req.Header.Add("grpc-metadata-a", "b")
|
||||||
req.Header.Add("grpc-metadata-b", "b")
|
req.Header.Add("grpc-metadata-b", "b")
|
||||||
assert.ElementsMatch(t, []string{"gateway-A:b", "gateway-B:b"}, ProcessHeaders(req.Header))
|
assert.ElementsMatch(t, []string{"gateway-a:b", "gateway-b:b"}, ProcessHeaders(req.Header))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessHeadersWithTraceContext(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||||
|
req.Header.Set("traceparent", "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
|
||||||
|
req.Header.Set("tracestate", "key1=value1,key2=value2")
|
||||||
|
req.Header.Set("baggage", "userId=alice,serverNode=DF:28")
|
||||||
|
|
||||||
|
headers := ProcessHeaders(req.Header)
|
||||||
|
|
||||||
|
assert.Len(t, headers, 3)
|
||||||
|
assert.Contains(t, headers, "traceparent:00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
|
||||||
|
assert.Contains(t, headers, "tracestate:key1=value1,key2=value2")
|
||||||
|
assert.Contains(t, headers, "baggage:userId=alice,serverNode=DF:28")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessHeadersWithMixedHeaders(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||||
|
req.Header.Set("traceparent", "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
|
||||||
|
req.Header.Set("grpc-metadata-custom", "value1")
|
||||||
|
req.Header.Set("content-type", "application/json")
|
||||||
|
req.Header.Set("tracestate", "key1=value1")
|
||||||
|
|
||||||
|
headers := ProcessHeaders(req.Header)
|
||||||
|
|
||||||
|
// Should include trace headers and grpc-metadata headers, but not regular headers
|
||||||
|
assert.Len(t, headers, 3)
|
||||||
|
assert.Contains(t, headers, "traceparent:00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
|
||||||
|
assert.Contains(t, headers, "tracestate:key1=value1")
|
||||||
|
assert.Contains(t, headers, "gateway-custom:value1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessHeadersTraceparentCaseInsensitive(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
headerKey string
|
||||||
|
headerVal string
|
||||||
|
expectedKey string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "lowercase traceparent",
|
||||||
|
headerKey: "traceparent",
|
||||||
|
headerVal: "00-trace-span-01",
|
||||||
|
expectedKey: "traceparent",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uppercase Traceparent",
|
||||||
|
headerKey: "Traceparent",
|
||||||
|
headerVal: "00-trace-span-01",
|
||||||
|
expectedKey: "traceparent",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case TraceParent",
|
||||||
|
headerKey: "TraceParent",
|
||||||
|
headerVal: "00-trace-span-01",
|
||||||
|
expectedKey: "traceparent",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "lowercase tracestate",
|
||||||
|
headerKey: "tracestate",
|
||||||
|
headerVal: "key=value",
|
||||||
|
expectedKey: "tracestate",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case TraceState",
|
||||||
|
headerKey: "TraceState",
|
||||||
|
headerVal: "key=value",
|
||||||
|
expectedKey: "tracestate",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||||
|
req.Header.Set(tt.headerKey, tt.headerVal)
|
||||||
|
|
||||||
|
headers := ProcessHeaders(req.Header)
|
||||||
|
|
||||||
|
assert.Len(t, headers, 1)
|
||||||
|
assert.Contains(t, headers, tt.expectedKey+":"+tt.headerVal)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessHeadersEmptyHeaders(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||||
|
headers := ProcessHeaders(req.Header)
|
||||||
|
assert.Empty(t, headers)
|
||||||
}
|
}
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -16,12 +16,12 @@ require (
|
|||||||
github.com/jhump/protoreflect v1.17.0
|
github.com/jhump/protoreflect v1.17.0
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2
|
github.com/pelletier/go-toml/v2 v2.2.2
|
||||||
github.com/prometheus/client_golang v1.21.1
|
github.com/prometheus/client_golang v1.21.1
|
||||||
github.com/redis/go-redis/v9 v9.15.0
|
github.com/redis/go-redis/v9 v9.16.0
|
||||||
github.com/spaolacci/murmur3 v1.1.0
|
github.com/spaolacci/murmur3 v1.1.0
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
go.etcd.io/etcd/api/v3 v3.5.15
|
go.etcd.io/etcd/api/v3 v3.5.15
|
||||||
go.etcd.io/etcd/client/v3 v3.5.15
|
go.etcd.io/etcd/client/v3 v3.5.15
|
||||||
go.mongodb.org/mongo-driver/v2 v2.3.0
|
go.mongodb.org/mongo-driver/v2 v2.4.0
|
||||||
go.opentelemetry.io/otel v1.24.0
|
go.opentelemetry.io/otel v1.24.0
|
||||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
|
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -154,8 +154,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
|
|||||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||||
github.com/redis/go-redis/v9 v9.15.0 h1:2jdes0xJxer4h3NUZrZ4OGSntGlXp4WbXju2nOTRXto=
|
github.com/redis/go-redis/v9 v9.16.0 h1:OotgqgLSRCmzfqChbQyG1PHC3tLNR89DG4jdOERSEP4=
|
||||||
github.com/redis/go-redis/v9 v9.15.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
github.com/redis/go-redis/v9 v9.16.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
@@ -197,8 +197,8 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5
|
|||||||
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
|
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
|
||||||
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
|
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
|
||||||
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
|
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
|
||||||
go.mongodb.org/mongo-driver/v2 v2.3.0 h1:sh55yOXA2vUjW1QYw/2tRlHSQViwDyPnW61AwpZ4rtU=
|
go.mongodb.org/mongo-driver/v2 v2.4.0 h1:Oq6BmUAAFTzMeh6AonuDlgZMuAuEiUxoAD1koK5MuFo=
|
||||||
go.mongodb.org/mongo-driver/v2 v2.3.0/go.mod h1:jHeEDJHJq7tm6ZF45Issun9dbogjfnPySb1vXA7EeAI=
|
go.mongodb.org/mongo-driver/v2 v2.4.0/go.mod h1:jHeEDJHJq7tm6ZF45Issun9dbogjfnPySb1vXA7EeAI=
|
||||||
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
||||||
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
||||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=
|
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func AddProbe(probe Probe) {
|
|||||||
defaultHealthManager.addProbe(probe)
|
defaultHealthManager.addProbe(probe)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateHttpHandler create health http handler base on given probe.
|
// CreateHttpHandler creates a health http handler based on the given probe.
|
||||||
func CreateHttpHandler(healthResponse string) http.HandlerFunc {
|
func CreateHttpHandler(healthResponse string) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, _ *http.Request) {
|
return func(w http.ResponseWriter, _ *http.Request) {
|
||||||
if defaultHealthManager.IsReady() {
|
if defaultHealthManager.IsReady() {
|
||||||
|
|||||||
@@ -175,7 +175,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro
|
|||||||
|
|
||||||
* API 文档
|
* API 文档
|
||||||
|
|
||||||
[https://go-zero.dev/cn/](https://go-zero.dev/cn/)
|
[https://go-zero.dev](https://go-zero.dev)
|
||||||
|
|
||||||
* awesome 系列(更多文章见『微服务实践』公众号)
|
* awesome 系列(更多文章见『微服务实践』公众号)
|
||||||
|
|
||||||
@@ -305,6 +305,9 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
|
|||||||
>107. 深圳市聚货通信息科技有限公司
|
>107. 深圳市聚货通信息科技有限公司
|
||||||
>108. 浙江银盾云科技有限公司
|
>108. 浙江银盾云科技有限公司
|
||||||
>109. 南京造世网络科技有限公司
|
>109. 南京造世网络科技有限公司
|
||||||
|
>110. 温州飞儿云信息技术有限公司
|
||||||
|
>111. 统信软件
|
||||||
|
>112. 深圳坐标软件集团有限公司
|
||||||
|
|
||||||
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ import '../vars/vars.dart';
|
|||||||
/// Send GET request.
|
/// Send GET request.
|
||||||
///
|
///
|
||||||
/// ok: the function that will be called on success.
|
/// ok: the function that will be called on success.
|
||||||
/// fail:the fuction that will be called on failure.
|
/// fail:the function that will be called on failure.
|
||||||
/// eventually:the function that will be called regardless of success or failure.
|
/// eventually:the function that will be called regardless of success or failure.
|
||||||
Future apiGet(String path,
|
Future apiGet(String path,
|
||||||
{Map<String, String> header,
|
{Map<String, String> header,
|
||||||
@@ -47,7 +47,7 @@ Future apiGet(String path,
|
|||||||
///
|
///
|
||||||
/// data: the data to post, it will be marshaled to json automatically.
|
/// data: the data to post, it will be marshaled to json automatically.
|
||||||
/// ok: the function that will be called on success.
|
/// ok: the function that will be called on success.
|
||||||
/// fail:the fuction that will be called on failure.
|
/// fail:the function that will be called on failure.
|
||||||
/// eventually:the function that will be called regardless of success or failure.
|
/// eventually:the function that will be called regardless of success or failure.
|
||||||
Future apiPost(String path, dynamic data,
|
Future apiPost(String path, dynamic data,
|
||||||
{Map<String, String> header,
|
{Map<String, String> header,
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ func DocCommand(_ *cobra.Command, _ []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !pathx.FileExists(dir) {
|
if !pathx.FileExists(dir) {
|
||||||
return fmt.Errorf("dir %s not exsit", dir)
|
return fmt.Errorf("dir %s not exist", dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
dir, err := filepath.Abs(dir)
|
dir, err := filepath.Abs(dir)
|
||||||
|
|||||||
@@ -78,9 +78,6 @@ service Test {
|
|||||||
// Logic should NOT have regular signature: Sse(req *types.SseReq) (resp *types.SseResp, err error)
|
// Logic should NOT have regular signature: Sse(req *types.SseReq) (resp *types.SseResp, err error)
|
||||||
assert.NotContains(t, logicStr, "(resp *types.SseResp, err error)",
|
assert.NotContains(t, logicStr, "(resp *types.SseResp, err error)",
|
||||||
"Logic should not have regular signature with resp return")
|
"Logic should not have regular signature with resp return")
|
||||||
|
|
||||||
t.Logf("Handler content:\n%s", handlerStr)
|
|
||||||
t.Logf("Logic content:\n%s", logicStr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNonSSEGeneration(t *testing.T) {
|
func TestNonSSEGeneration(t *testing.T) {
|
||||||
@@ -153,7 +150,4 @@ service Test {
|
|||||||
}
|
}
|
||||||
assert.False(t, hasSSESignature,
|
assert.False(t, hasSSESignature,
|
||||||
"Logic should not have SSE signature with client channel parameter")
|
"Logic should not have SSE signature with client channel parameter")
|
||||||
|
|
||||||
t.Logf("Handler content:\n%s", handlerStr)
|
|
||||||
t.Logf("Logic content:\n%s", logicStr)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,13 +38,9 @@ func TestServerIntegration(t *testing.T) {
|
|||||||
ctx := svc.NewServiceContext(c)
|
ctx := svc.NewServiceContext(c)
|
||||||
handler.RegisterHandlers(server, ctx)
|
handler.RegisterHandlers(server, ctx)
|
||||||
|
|
||||||
// Start server in background
|
// Create serverless wrapper for testing
|
||||||
go func() {
|
serverless, err := rest.NewServerless(server)
|
||||||
server.Start()
|
require.NoError(t, err)
|
||||||
}()
|
|
||||||
|
|
||||||
// Wait for server to start
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -56,7 +52,7 @@ func TestServerIntegration(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "health check",
|
name: "health check",
|
||||||
method: "GET",
|
method: http.MethodGet,
|
||||||
path: "/health",
|
path: "/health",
|
||||||
expectedStatus: http.StatusNotFound, // Adjust based on actual routes
|
expectedStatus: http.StatusNotFound, // Adjust based on actual routes
|
||||||
setup: func() {},
|
setup: func() {},
|
||||||
@@ -72,7 +68,7 @@ func TestServerIntegration(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{{end}}{{end}}{
|
{{end}}{{end}}{
|
||||||
name: "not found route",
|
name: "not found route",
|
||||||
method: "GET",
|
method: http.MethodGet,
|
||||||
path: "/nonexistent",
|
path: "/nonexistent",
|
||||||
expectedStatus: http.StatusNotFound,
|
expectedStatus: http.StatusNotFound,
|
||||||
setup: func() {},
|
setup: func() {},
|
||||||
@@ -87,7 +83,7 @@ func TestServerIntegration(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
server.ServeHTTP(rr, req)
|
serverless.Serve(rr, req)
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedStatus, rr.Code)
|
assert.Equal(t, tt.expectedStatus, rr.Code)
|
||||||
|
|
||||||
|
|||||||
17
tools/goctl/api/gogen/jwt.api
Executable file
17
tools/goctl/api/gogen/jwt.api
Executable file
@@ -0,0 +1,17 @@
|
|||||||
|
type Request {
|
||||||
|
Name string `path:"name,options=you|me"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Response {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
@server(
|
||||||
|
jwt: Auth
|
||||||
|
jwtTransition: Trans
|
||||||
|
middleware: TokenValidate
|
||||||
|
)
|
||||||
|
service A-api {
|
||||||
|
@handler GreetHandler
|
||||||
|
get /greet/from/:name(Request) returns (Response)
|
||||||
|
}
|
||||||
@@ -268,7 +268,7 @@ func (v *ApiVisitor) VisitReplybody(ctx *api.ReplybodyContext) any {
|
|||||||
v.panic(lit.Expr(), fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit.Expr().Text()))
|
v.panic(lit.Expr(), fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit.Expr().Text()))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text()))
|
v.panic(dt.Expr(), fmt.Sprintf("unsupported %s", dt.Expr().Text()))
|
||||||
}
|
}
|
||||||
case *Literal:
|
case *Literal:
|
||||||
lit := dataType.Literal.Text()
|
lit := dataType.Literal.Text()
|
||||||
@@ -276,7 +276,7 @@ func (v *ApiVisitor) VisitReplybody(ctx *api.ReplybodyContext) any {
|
|||||||
v.panic(dataType.Literal, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit))
|
v.panic(dataType.Literal, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text()))
|
v.panic(dt.Expr(), fmt.Sprintf("unsupported %s", dt.Expr().Text()))
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Body{
|
return &Body{
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ func (v *ApiVisitor) VisitTypeBlockStruct(ctx *api.TypeBlockStructContext) any {
|
|||||||
structExpr := v.newExprWithToken(ctx.GetStructToken())
|
structExpr := v.newExprWithToken(ctx.GetStructToken())
|
||||||
structTokenText := ctx.GetStructToken().GetText()
|
structTokenText := ctx.GetStructToken().GetText()
|
||||||
if structTokenText != "struct" {
|
if structTokenText != "struct" {
|
||||||
v.panic(structExpr, fmt.Sprintf("expecting 'struct', found imput '%s'", structTokenText))
|
v.panic(structExpr, fmt.Sprintf("expecting 'struct', found input '%s'", structTokenText))
|
||||||
}
|
}
|
||||||
|
|
||||||
if api.IsGolangKeyWord(structTokenText, "struct") {
|
if api.IsGolangKeyWord(structTokenText, "struct") {
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ type parser struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse parses the api file.
|
// Parse parses the api file.
|
||||||
// Depreacted: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
|
// Deprecated: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
|
||||||
// it will be removed in the future.
|
// it will be removed in the future.
|
||||||
func Parse(filename string) (*spec.ApiSpec, error) {
|
func Parse(filename string) (*spec.ApiSpec, error) {
|
||||||
if env.UseExperimental() {
|
if env.UseExperimental() {
|
||||||
@@ -63,14 +63,14 @@ func parseContent(content string, skipCheckTypeDeclaration bool, filename ...str
|
|||||||
return apiSpec, nil
|
return apiSpec, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Depreacted: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
|
// Deprecated: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
|
||||||
// it will be removed in the future.
|
// it will be removed in the future.
|
||||||
// ParseContent parses the api content
|
// ParseContent parses the api content
|
||||||
func ParseContent(content string, filename ...string) (*spec.ApiSpec, error) {
|
func ParseContent(content string, filename ...string) (*spec.ApiSpec, error) {
|
||||||
return parseContent(content, false, filename...)
|
return parseContent(content, false, filename...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Depreacted: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
|
// Deprecated: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
|
||||||
// it will be removed in the future.
|
// it will be removed in the future.
|
||||||
// ParseContentWithParserSkipCheckTypeDeclaration parses the api content with skip check type declaration
|
// ParseContentWithParserSkipCheckTypeDeclaration parses the api content with skip check type declaration
|
||||||
func ParseContentWithParserSkipCheckTypeDeclaration(content string, filename ...string) (*spec.ApiSpec, error) {
|
func ParseContentWithParserSkipCheckTypeDeclaration(content string, filename ...string) (*spec.ApiSpec, error) {
|
||||||
@@ -227,7 +227,7 @@ func (p parser) astTypeToSpec(in ast.DataType) spec.Type {
|
|||||||
return spec.PointerType{RawName: v.PointerExpr.Text(), Type: spec.DefineStruct{RawName: raw}}
|
return spec.PointerType{RawName: v.PointerExpr.Text(), Type: spec.DefineStruct{RawName: raw}}
|
||||||
}
|
}
|
||||||
|
|
||||||
panic(fmt.Sprintf("unspported type %+v", in))
|
panic(fmt.Sprintf("unsupported type %+v", in))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p parser) stringExprs(docs []ast.Expr) []string {
|
func (p parser) stringExprs(docs []ast.Expr) []string {
|
||||||
|
|||||||
@@ -24,10 +24,15 @@ func getFirstUsableString(def ...string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, val := range def {
|
for _, val := range def {
|
||||||
str, err := strconv.Unquote(val)
|
// Try to unquote if it's a quoted string
|
||||||
if err == nil && len(str) != 0 {
|
if str, err := strconv.Unquote(val); err == nil && len(str) != 0 {
|
||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Otherwise, use the value as-is if it's not empty
|
||||||
|
if len(val) != 0 {
|
||||||
|
return val
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -89,3 +89,108 @@ func Test_getListFromInfoOrDefault(t *testing.T) {
|
|||||||
assert.Equal(t, []string{"query"}, getListFromInfoOrDefault(unquotedProperties, "tags", []string{"default"}))
|
assert.Equal(t, []string{"query"}, getListFromInfoOrDefault(unquotedProperties, "tags", []string{"default"}))
|
||||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(unquotedProperties, "empty", []string{"default"}))
|
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(unquotedProperties, "empty", []string{"default"}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_getFirstUsableString(t *testing.T) {
|
||||||
|
t.Run("empty input", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString()
|
||||||
|
assert.Equal(t, "", result, "should return empty string for no arguments")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("single plain string", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("Check server health status.")
|
||||||
|
assert.Equal(t, "Check server health status.", result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("single quoted string", func(t *testing.T) {
|
||||||
|
// This is how Go would represent a quoted string literal
|
||||||
|
result := getFirstUsableString(`"Check server health status."`)
|
||||||
|
assert.Equal(t, "Check server health status.", result, "should unquote quoted strings")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple plain strings", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("", "second", "third")
|
||||||
|
assert.Equal(t, "second", result, "should return first non-empty string")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handler name fallback", func(t *testing.T) {
|
||||||
|
// Simulates the real use case: @doc text, handler name
|
||||||
|
result := getFirstUsableString("", "HealthCheck")
|
||||||
|
assert.Equal(t, "HealthCheck", result, "should fallback to handler name")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("doc text over handler name", func(t *testing.T) {
|
||||||
|
// Simulates the real use case with @doc text
|
||||||
|
result := getFirstUsableString("Check server health status.", "HealthCheck")
|
||||||
|
assert.Equal(t, "Check server health status.", result, "should use doc text over handler name")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty strings before valid", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("", "", "valid")
|
||||||
|
assert.Equal(t, "valid", result, "should skip empty strings")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("all empty strings", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("", "", "")
|
||||||
|
assert.Equal(t, "", result, "should return empty if all are empty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("quoted then plain", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString(`"quoted"`, "plain")
|
||||||
|
assert.Equal(t, "quoted", result, "should unquote first quoted string")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("plain then quoted", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("plain", `"quoted"`)
|
||||||
|
assert.Equal(t, "plain", result, "should use first plain string")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid quoted string", func(t *testing.T) {
|
||||||
|
// String that looks quoted but isn't valid Go syntax
|
||||||
|
result := getFirstUsableString(`"incomplete`, "fallback")
|
||||||
|
assert.Equal(t, `"incomplete`, result, "should use as-is if unquote fails but not empty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("whitespace only", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString(" ", "fallback")
|
||||||
|
assert.Equal(t, " ", result, "should not trim whitespace, return as-is")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("real world API doc scenario", func(t *testing.T) {
|
||||||
|
// This is the actual bug scenario from issue #5229
|
||||||
|
atDocText := "Check server health status."
|
||||||
|
handlerName := "HealthCheck"
|
||||||
|
|
||||||
|
result := getFirstUsableString(atDocText, handlerName)
|
||||||
|
assert.Equal(t, "Check server health status.", result,
|
||||||
|
"should use @doc text for API summary")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("real world with empty doc", func(t *testing.T) {
|
||||||
|
// When @doc is empty, should fall back to handler name
|
||||||
|
atDocText := ""
|
||||||
|
handlerName := "HealthCheck"
|
||||||
|
|
||||||
|
result := getFirstUsableString(atDocText, handlerName)
|
||||||
|
assert.Equal(t, "HealthCheck", result,
|
||||||
|
"should fallback to handler name when @doc is empty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("complex summary with special characters", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("Get user by ID: /users/{id}")
|
||||||
|
assert.Equal(t, "Get user by ID: /users/{id}", result,
|
||||||
|
"should handle special characters in plain strings")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiline string", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("Line 1\nLine 2")
|
||||||
|
assert.Equal(t, "Line 1\nLine 2", result,
|
||||||
|
"should handle multiline strings")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unicode characters", func(t *testing.T) {
|
||||||
|
result := getFirstUsableString("健康检查", "HealthCheck")
|
||||||
|
assert.Equal(t, "健康检查", result,
|
||||||
|
"should handle unicode characters")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,28 +8,37 @@ import (
|
|||||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isPostJson(ctx Context, method string, tp apiSpec.Type) (string, bool) {
|
func isRequestBodyJson(ctx Context, method string, tp apiSpec.Type) (string, bool) {
|
||||||
if !strings.EqualFold(method, http.MethodPost) {
|
// Support HTTP methods that commonly use request bodies with JSON
|
||||||
|
// POST, PUT, PATCH are standard methods with bodies
|
||||||
|
// DELETE can also have a body (though less common)
|
||||||
|
method = strings.ToUpper(method)
|
||||||
|
if method != http.MethodPost && method != http.MethodPut &&
|
||||||
|
method != http.MethodPatch && method != http.MethodDelete {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
structType, ok := tp.(apiSpec.DefineStruct)
|
structType, ok := tp.(apiSpec.DefineStruct)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
var isPostJson bool
|
|
||||||
|
var hasJsonField bool
|
||||||
rangeMemberAndDo(ctx, structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
rangeMemberAndDo(ctx, structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
||||||
jsonTag, _ := tag.Get(tagJson)
|
jsonTag, _ := tag.Get(tagJson)
|
||||||
if !isPostJson {
|
if !hasJsonField {
|
||||||
isPostJson = jsonTag != nil
|
hasJsonField = jsonTag != nil
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return structType.RawName, isPostJson
|
|
||||||
|
return structType.RawName, hasJsonField
|
||||||
}
|
}
|
||||||
|
|
||||||
func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Parameter {
|
func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Parameter {
|
||||||
if tp == nil {
|
if tp == nil {
|
||||||
return []spec.Parameter{}
|
return []spec.Parameter{}
|
||||||
}
|
}
|
||||||
|
|
||||||
structType, ok := tp.(apiSpec.DefineStruct)
|
structType, ok := tp.(apiSpec.DefineStruct)
|
||||||
if !ok {
|
if !ok {
|
||||||
return []spec.Parameter{}
|
return []spec.Parameter{}
|
||||||
@@ -43,15 +52,13 @@ func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Para
|
|||||||
rangeMemberAndDo(ctx, structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
rangeMemberAndDo(ctx, structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
||||||
headerTag, _ := tag.Get(tagHeader)
|
headerTag, _ := tag.Get(tagHeader)
|
||||||
hasHeader := headerTag != nil
|
hasHeader := headerTag != nil
|
||||||
|
|
||||||
pathParameterTag, _ := tag.Get(tagPath)
|
pathParameterTag, _ := tag.Get(tagPath)
|
||||||
hasPathParameter := pathParameterTag != nil
|
hasPathParameter := pathParameterTag != nil
|
||||||
|
|
||||||
formTag, _ := tag.Get(tagForm)
|
formTag, _ := tag.Get(tagForm)
|
||||||
hasForm := formTag != nil
|
hasForm := formTag != nil
|
||||||
|
|
||||||
jsonTag, _ := tag.Get(tagJson)
|
jsonTag, _ := tag.Get(tagJson)
|
||||||
hasJson := jsonTag != nil
|
hasJson := jsonTag != nil
|
||||||
|
|
||||||
if hasHeader {
|
if hasHeader {
|
||||||
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(headerTag.Options)
|
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(headerTag.Options)
|
||||||
resp = append(resp, spec.Parameter{
|
resp = append(resp, spec.Parameter{
|
||||||
@@ -75,6 +82,7 @@ func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Para
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasPathParameter {
|
if hasPathParameter {
|
||||||
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(pathParameterTag.Options)
|
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(pathParameterTag.Options)
|
||||||
resp = append(resp, spec.Parameter{
|
resp = append(resp, spec.Parameter{
|
||||||
@@ -98,6 +106,7 @@ func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Para
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasForm {
|
if hasForm {
|
||||||
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(formTag.Options)
|
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(formTag.Options)
|
||||||
if strings.EqualFold(method, http.MethodGet) {
|
if strings.EqualFold(method, http.MethodGet) {
|
||||||
@@ -145,8 +154,8 @@ func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Para
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasJson {
|
if hasJson {
|
||||||
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(jsonTag.Options)
|
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(jsonTag.Options)
|
||||||
if required {
|
if required {
|
||||||
@@ -179,9 +188,10 @@ func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Para
|
|||||||
properties[jsonTag.Name] = schema
|
properties[jsonTag.Name] = schema
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
if len(properties) > 0 {
|
if len(properties) > 0 {
|
||||||
if ctx.UseDefinitions {
|
if ctx.UseDefinitions {
|
||||||
structName, ok := isPostJson(ctx, method, tp)
|
structName, ok := isRequestBodyJson(ctx, method, tp)
|
||||||
if ok {
|
if ok {
|
||||||
resp = append(resp, spec.Parameter{
|
resp = append(resp, spec.Parameter{
|
||||||
ParamProps: spec.ParamProps{
|
ParamProps: spec.ParamProps{
|
||||||
@@ -213,5 +223,6 @@ func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Para
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIsPostJson(t *testing.T) {
|
func TestIsRequestBodyJson(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
method string
|
method string
|
||||||
@@ -18,13 +18,18 @@ func TestIsPostJson(t *testing.T) {
|
|||||||
{"POST with JSON", http.MethodPost, true, true},
|
{"POST with JSON", http.MethodPost, true, true},
|
||||||
{"POST without JSON", http.MethodPost, false, false},
|
{"POST without JSON", http.MethodPost, false, false},
|
||||||
{"GET with JSON", http.MethodGet, true, false},
|
{"GET with JSON", http.MethodGet, true, false},
|
||||||
{"PUT with JSON", http.MethodPut, true, false},
|
{"PUT with JSON", http.MethodPut, true, true},
|
||||||
|
{"PUT without JSON", http.MethodPut, false, false},
|
||||||
|
{"PATCH with JSON", http.MethodPatch, true, true},
|
||||||
|
{"PATCH without JSON", http.MethodPatch, false, false},
|
||||||
|
{"DELETE with JSON", http.MethodDelete, true, true},
|
||||||
|
{"DELETE without JSON", http.MethodDelete, false, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
testStruct := createTestStruct("TestStruct", tt.hasJson)
|
testStruct := createTestStruct("TestStruct", tt.hasJson)
|
||||||
_, result := isPostJson(testingContext(t), tt.method, testStruct)
|
_, result := isRequestBodyJson(testingContext(t), tt.method, testStruct)
|
||||||
assert.Equal(t, tt.expected, result)
|
assert.Equal(t, tt.expected, result)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -41,6 +46,12 @@ func TestParametersFromType(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"POST JSON with definitions", http.MethodPost, true, true, 1, true},
|
{"POST JSON with definitions", http.MethodPost, true, true, 1, true},
|
||||||
{"POST JSON without definitions", http.MethodPost, false, true, 1, true},
|
{"POST JSON without definitions", http.MethodPost, false, true, 1, true},
|
||||||
|
{"PUT JSON with definitions", http.MethodPut, true, true, 1, true},
|
||||||
|
{"PUT JSON without definitions", http.MethodPut, false, true, 1, true},
|
||||||
|
{"PATCH JSON with definitions", http.MethodPatch, true, true, 1, true},
|
||||||
|
{"PATCH JSON without definitions", http.MethodPatch, false, true, 1, true},
|
||||||
|
{"DELETE JSON with definitions", http.MethodDelete, true, true, 1, true},
|
||||||
|
{"DELETE JSON without definitions", http.MethodDelete, false, true, 1, true},
|
||||||
{"GET with form", http.MethodGet, false, false, 1, false},
|
{"GET with form", http.MethodGet, false, false, 1, false},
|
||||||
{"POST with form", http.MethodPost, false, false, 1, false},
|
{"POST with form", http.MethodPost, false, false, 1, false},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@
|
|||||||
## swagger
|
## swagger
|
||||||
1. [bug fix] remove example generation when request body are `query`, `path` and `header`
|
1. [bug fix] remove example generation when request body are `query`, `path` and `header`
|
||||||
- it not supported in api spec 2.0
|
- it not supported in api spec 2.0
|
||||||
- it's will generate example when request body is json format.
|
- it will generate example when request body is json format.
|
||||||
2. [features] swagger generation supported definitions
|
2. [features] swagger generation supported definitions
|
||||||
- supported response definitions
|
- supported response definitions
|
||||||
- supported json request body definitions
|
- supported json request body definitions
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ func dockerCommand(_ *cobra.Command, _ []string) (err error) {
|
|||||||
|
|
||||||
base := varStringBase
|
base := varStringBase
|
||||||
port := varIntPort
|
port := varIntPort
|
||||||
|
etcDir := filepath.Join(filepath.Dir(goFile), etcDir)
|
||||||
if _, err := os.Stat(etcDir); os.IsNotExist(err) {
|
if _, err := os.Stat(etcDir); os.IsNotExist(err) {
|
||||||
return generateDockerfile(goFile, base, port, version, timezone)
|
return generateDockerfile(goFile, base, port, version, timezone)
|
||||||
}
|
}
|
||||||
@@ -170,7 +171,7 @@ func generateDockerfile(goFile, base string, port int, version, timezone string,
|
|||||||
t := template.Must(template.New("dockerfile").Parse(text))
|
t := template.Must(template.New("dockerfile").Parse(text))
|
||||||
return t.Execute(out, Docker{
|
return t.Execute(out, Docker{
|
||||||
Chinese: env.InChina(),
|
Chinese: env.InChina(),
|
||||||
GoMainFrom: path.Join(projPath, goFile),
|
GoMainFrom: path.Join(projPath, filepath.Base(goFile)),
|
||||||
GoRelPath: projPath,
|
GoRelPath: projPath,
|
||||||
GoFile: goFile,
|
GoFile: goFile,
|
||||||
ExeFile: exeName,
|
ExeFile: exeName,
|
||||||
|
|||||||
376
tools/goctl/docker/docker_test.go
Normal file
376
tools/goctl/docker/docker_test.go
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
package docker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDockerCommand_EtcDirResolution(t *testing.T) {
|
||||||
|
// Create a temporary project structure
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Create project structure: project/service/api/
|
||||||
|
serviceDir := filepath.Join(tempDir, "service", "api")
|
||||||
|
etcDir := filepath.Join(serviceDir, "etc")
|
||||||
|
require.NoError(t, os.MkdirAll(etcDir, 0755))
|
||||||
|
|
||||||
|
// Create a Go file
|
||||||
|
goFile := filepath.Join(serviceDir, "api.go")
|
||||||
|
require.NoError(t, os.WriteFile(goFile, []byte("package main\n\nfunc main() {}"), 0644))
|
||||||
|
|
||||||
|
// Create a config file
|
||||||
|
configFile := filepath.Join(etcDir, "config.yaml")
|
||||||
|
require.NoError(t, os.WriteFile(configFile, []byte("Name: test\n"), 0644))
|
||||||
|
|
||||||
|
// Create go.mod at the root
|
||||||
|
goModFile := filepath.Join(tempDir, "go.mod")
|
||||||
|
require.NoError(t, os.WriteFile(goModFile, []byte("module test\n\ngo 1.21\n"), 0644))
|
||||||
|
|
||||||
|
// Test: etc directory should be found relative to Go file, not CWD
|
||||||
|
t.Run("etc directory resolved relative to go file", func(t *testing.T) {
|
||||||
|
// Save and restore original working directory
|
||||||
|
originalWd, err := os.Getwd()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, os.Chdir(originalWd))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Change to temp directory (not service/api directory)
|
||||||
|
require.NoError(t, os.Chdir(tempDir))
|
||||||
|
|
||||||
|
// The relative path from tempDir to the go file
|
||||||
|
relGoFile := filepath.Join("service", "api", "api.go")
|
||||||
|
|
||||||
|
// Test the etc directory resolution logic
|
||||||
|
resolvedEtcDir := filepath.Join(filepath.Dir(relGoFile), "etc")
|
||||||
|
|
||||||
|
// Verify the resolved path exists
|
||||||
|
_, err = os.Stat(resolvedEtcDir)
|
||||||
|
assert.NoError(t, err, "etc directory should be found at service/api/etc")
|
||||||
|
|
||||||
|
// Verify it's the correct path (use EvalSymlinks to handle /private on macOS)
|
||||||
|
absResolvedEtc, err := filepath.Abs(resolvedEtcDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
absResolvedEtc, err = filepath.EvalSymlinks(absResolvedEtc)
|
||||||
|
require.NoError(t, err)
|
||||||
|
expectedEtc, err := filepath.EvalSymlinks(etcDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, expectedEtc, absResolvedEtc)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("etc directory with empty goFile", func(t *testing.T) {
|
||||||
|
// When goFile is empty, should default to "./etc"
|
||||||
|
goFile := ""
|
||||||
|
resolvedEtcDir := filepath.Join(filepath.Dir(goFile), "etc")
|
||||||
|
|
||||||
|
// Should resolve to just "etc"
|
||||||
|
assert.Equal(t, "etc", resolvedEtcDir)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("etc directory with absolute path", func(t *testing.T) {
|
||||||
|
// When goFile is absolute path
|
||||||
|
absGoFile := filepath.Join(tempDir, "service", "api", "api.go")
|
||||||
|
resolvedEtcDir := filepath.Join(filepath.Dir(absGoFile), "etc")
|
||||||
|
|
||||||
|
// Should resolve correctly
|
||||||
|
_, err := os.Stat(resolvedEtcDir)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateDockerfile_GoMainFromPath(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
goFile string
|
||||||
|
projPath string
|
||||||
|
expectedPath string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "relative path with subdirectory",
|
||||||
|
goFile: "service/api/api.go",
|
||||||
|
projPath: "service/api",
|
||||||
|
expectedPath: "service/api/api.go",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple filename",
|
||||||
|
goFile: "main.go",
|
||||||
|
projPath: ".",
|
||||||
|
expectedPath: "main.go",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested service path",
|
||||||
|
goFile: "internal/service/user/user.go",
|
||||||
|
projPath: "internal/service/user",
|
||||||
|
expectedPath: "internal/service/user/user.go",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deep nested path",
|
||||||
|
goFile: "cmd/api/internal/handler/handler.go",
|
||||||
|
projPath: "cmd/api/internal/handler",
|
||||||
|
expectedPath: "cmd/api/internal/handler/handler.go",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Simulate the fix: using filepath.Base instead of full path
|
||||||
|
goMainFrom := filepath.Join(tt.projPath, filepath.Base(tt.goFile))
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedPath, goMainFrom,
|
||||||
|
"GoMainFrom should not duplicate path segments")
|
||||||
|
|
||||||
|
// Verify the old buggy behavior would have been wrong
|
||||||
|
if tt.goFile != filepath.Base(tt.goFile) {
|
||||||
|
buggyPath := filepath.Join(tt.projPath, tt.goFile)
|
||||||
|
assert.NotEqual(t, tt.expectedPath, buggyPath,
|
||||||
|
"Old implementation would have created incorrect path")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateDockerfile_PathJoinBehavior(t *testing.T) {
|
||||||
|
t.Run("demonstrates the bug and fix", func(t *testing.T) {
|
||||||
|
projPath := "service/api"
|
||||||
|
goFile := "service/api/api.go"
|
||||||
|
|
||||||
|
// OLD (buggy) behavior: path duplication
|
||||||
|
buggyPath := filepath.Join(projPath, goFile)
|
||||||
|
assert.Equal(t, "service/api/service/api/api.go", buggyPath,
|
||||||
|
"Bug: path segments are duplicated")
|
||||||
|
|
||||||
|
// NEW (fixed) behavior: correct path
|
||||||
|
fixedPath := filepath.Join(projPath, filepath.Base(goFile))
|
||||||
|
assert.Equal(t, "service/api/api.go", fixedPath,
|
||||||
|
"Fix: using filepath.Base prevents duplication")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindConfig(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
etcDir := filepath.Join(tempDir, "etc")
|
||||||
|
require.NoError(t, os.MkdirAll(etcDir, 0755))
|
||||||
|
|
||||||
|
t.Run("finds config matching go file name", func(t *testing.T) {
|
||||||
|
// Create config files
|
||||||
|
require.NoError(t, os.WriteFile(
|
||||||
|
filepath.Join(etcDir, "api.yaml"), []byte("test"), 0644))
|
||||||
|
require.NoError(t, os.WriteFile(
|
||||||
|
filepath.Join(etcDir, "other.yaml"), []byte("test"), 0644))
|
||||||
|
|
||||||
|
cfg, err := findConfig("api.go", etcDir)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "api.yaml", cfg)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns first config when no match", func(t *testing.T) {
|
||||||
|
etcDir2 := filepath.Join(tempDir, "etc2")
|
||||||
|
require.NoError(t, os.MkdirAll(etcDir2, 0755))
|
||||||
|
require.NoError(t, os.WriteFile(
|
||||||
|
filepath.Join(etcDir2, "config.yaml"), []byte("test"), 0644))
|
||||||
|
|
||||||
|
cfg, err := findConfig("main.go", etcDir2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "config.yaml", cfg)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error when no yaml files", func(t *testing.T) {
|
||||||
|
emptyDir := filepath.Join(tempDir, "empty")
|
||||||
|
require.NoError(t, os.MkdirAll(emptyDir, 0755))
|
||||||
|
|
||||||
|
_, err := findConfig("api.go", emptyDir)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no yaml file")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles path in go file name", func(t *testing.T) {
|
||||||
|
// Test with service/api/api.go - should extract just "api"
|
||||||
|
cfg, err := findConfig("service/api/api.go", etcDir)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "api.yaml", cfg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFilePath(t *testing.T) {
|
||||||
|
// Create a temporary directory with go.mod
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
require.NoError(t, os.WriteFile(
|
||||||
|
filepath.Join(tempDir, "go.mod"),
|
||||||
|
[]byte("module testproject\n\ngo 1.21\n"),
|
||||||
|
0644,
|
||||||
|
))
|
||||||
|
|
||||||
|
// Create subdirectories
|
||||||
|
serviceDir := filepath.Join(tempDir, "service", "api")
|
||||||
|
require.NoError(t, os.MkdirAll(serviceDir, 0755))
|
||||||
|
|
||||||
|
originalWd, err := os.Getwd()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, os.Chdir(originalWd))
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Run("returns relative path from go.mod", func(t *testing.T) {
|
||||||
|
require.NoError(t, os.Chdir(tempDir))
|
||||||
|
|
||||||
|
path, err := getFilePath("service/api")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "service/api", path)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles current directory", func(t *testing.T) {
|
||||||
|
require.NoError(t, os.Chdir(tempDir))
|
||||||
|
|
||||||
|
path, err := getFilePath(".")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// Current directory returns empty string when at go.mod root
|
||||||
|
assert.True(t, path == "." || path == "")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Integration test to verify the complete fix
|
||||||
|
func TestDockerCommandIntegration(t *testing.T) {
|
||||||
|
// Create a complete project structure
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Setup: project/service/api/
|
||||||
|
serviceDir := filepath.Join(tempDir, "service", "api")
|
||||||
|
etcDir := filepath.Join(serviceDir, "etc")
|
||||||
|
require.NoError(t, os.MkdirAll(etcDir, 0755))
|
||||||
|
|
||||||
|
// Create files
|
||||||
|
goFile := filepath.Join(serviceDir, "api.go")
|
||||||
|
require.NoError(t, os.WriteFile(goFile, []byte("package main\n\nfunc main() {}"), 0644))
|
||||||
|
configFile := filepath.Join(etcDir, "api.yaml")
|
||||||
|
require.NoError(t, os.WriteFile(configFile, []byte("Name: test-api\n"), 0644))
|
||||||
|
goModFile := filepath.Join(tempDir, "go.mod")
|
||||||
|
require.NoError(t, os.WriteFile(goModFile, []byte("module testproject\n\ngo 1.21\n"), 0644))
|
||||||
|
goSumFile := filepath.Join(tempDir, "go.sum")
|
||||||
|
require.NoError(t, os.WriteFile(goSumFile, []byte(""), 0644))
|
||||||
|
|
||||||
|
originalWd, err := os.Getwd()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, os.Chdir(originalWd))
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Run("etc directory detected from different working directory", func(t *testing.T) {
|
||||||
|
// Change to project root (not service/api)
|
||||||
|
require.NoError(t, os.Chdir(tempDir))
|
||||||
|
|
||||||
|
// Relative path to Go file
|
||||||
|
relGoFile := filepath.Join("service", "api", "api.go")
|
||||||
|
|
||||||
|
// Apply the fix: resolve etc directory relative to go file
|
||||||
|
resolvedEtcDir := filepath.Join(filepath.Dir(relGoFile), "etc")
|
||||||
|
|
||||||
|
// Verify etc directory is found
|
||||||
|
stat, err := os.Stat(resolvedEtcDir)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, stat.IsDir())
|
||||||
|
|
||||||
|
// Verify config can be found
|
||||||
|
cfg, err := findConfig(relGoFile, resolvedEtcDir)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "api.yaml", cfg)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GoMainFrom path is correct", func(t *testing.T) {
|
||||||
|
require.NoError(t, os.Chdir(tempDir))
|
||||||
|
|
||||||
|
goFileRel := filepath.Join("service", "api", "api.go")
|
||||||
|
|
||||||
|
// Simulate getFilePath return value
|
||||||
|
projPath := "service/api"
|
||||||
|
|
||||||
|
// Apply the fix: use filepath.Base
|
||||||
|
goMainFrom := filepath.Join(projPath, filepath.Base(goFileRel))
|
||||||
|
|
||||||
|
assert.Equal(t, "service/api/api.go", goMainFrom)
|
||||||
|
|
||||||
|
// Verify no path duplication
|
||||||
|
assert.NotContains(t, goMainFrom, "service/api/service/api")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that specifically validates the bug described in PR #4343
|
||||||
|
func TestPR4343_BugFixes(t *testing.T) {
|
||||||
|
t.Run("Bug 1: etc directory check uses correct base path", func(t *testing.T) {
|
||||||
|
// Setup: Create a project structure where etc is NOT in CWD but IS relative to Go file
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
serviceDir := filepath.Join(tempDir, "service", "api")
|
||||||
|
etcDir := filepath.Join(serviceDir, "etc")
|
||||||
|
require.NoError(t, os.MkdirAll(etcDir, 0755))
|
||||||
|
|
||||||
|
// Create a config file
|
||||||
|
require.NoError(t, os.WriteFile(
|
||||||
|
filepath.Join(etcDir, "config.yaml"),
|
||||||
|
[]byte("Name: test\n"),
|
||||||
|
0644,
|
||||||
|
))
|
||||||
|
|
||||||
|
originalWd, err := os.Getwd()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, os.Chdir(originalWd))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Change to project root (CWD = tempDir)
|
||||||
|
require.NoError(t, os.Chdir(tempDir))
|
||||||
|
|
||||||
|
goFile := filepath.Join("service", "api", "api.go")
|
||||||
|
|
||||||
|
// OLD (buggy) behavior: checks for "etc" in CWD
|
||||||
|
_, errOld := os.Stat("etc")
|
||||||
|
assert.Error(t, errOld, "Bug: etc not found in CWD")
|
||||||
|
|
||||||
|
// NEW (fixed) behavior: checks for "etc" relative to go file
|
||||||
|
etcDirResolved := filepath.Join(filepath.Dir(goFile), "etc")
|
||||||
|
stat, errNew := os.Stat(etcDirResolved)
|
||||||
|
assert.NoError(t, errNew, "Fix: etc found relative to go file")
|
||||||
|
assert.True(t, stat.IsDir())
|
||||||
|
|
||||||
|
// Verify config is accessible
|
||||||
|
cfg, err := findConfig(goFile, etcDirResolved)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "config.yaml", cfg)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Bug 2: GoMainFrom path not duplicated", func(t *testing.T) {
|
||||||
|
// Test case from PR description
|
||||||
|
projPath := "service/api"
|
||||||
|
goFile := "service/api/api.go"
|
||||||
|
|
||||||
|
// OLD (buggy) behavior: duplicates path
|
||||||
|
buggyPath := filepath.Join(projPath, goFile)
|
||||||
|
assert.Equal(t, "service/api/service/api/api.go", buggyPath,
|
||||||
|
"Bug: path duplication occurs with old implementation")
|
||||||
|
|
||||||
|
// NEW (fixed) behavior: correct path using filepath.Base
|
||||||
|
fixedPath := filepath.Join(projPath, filepath.Base(goFile))
|
||||||
|
assert.Equal(t, "service/api/api.go", fixedPath,
|
||||||
|
"Fix: using filepath.Base() prevents path duplication")
|
||||||
|
|
||||||
|
// Verify the fix works for various scenarios
|
||||||
|
testCases := []struct {
|
||||||
|
projPath string
|
||||||
|
goFile string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"service/api", "service/api/api.go", "service/api/api.go"},
|
||||||
|
{"cmd/server", "cmd/server/main.go", "cmd/server/main.go"},
|
||||||
|
{"internal/handler", "internal/handler/handler.go", "internal/handler/handler.go"},
|
||||||
|
{".", "main.go", "main.go"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
result := filepath.Join(tc.projPath, filepath.Base(tc.goFile))
|
||||||
|
assert.Equal(t, tc.expected, result,
|
||||||
|
"Fix should work for projPath=%s, goFile=%s", tc.projPath, tc.goFile)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -16,7 +16,7 @@ require (
|
|||||||
github.com/withfig/autocomplete-tools/integrations/cobra v1.2.1
|
github.com/withfig/autocomplete-tools/integrations/cobra v1.2.1
|
||||||
github.com/zeromicro/antlr v0.0.1
|
github.com/zeromicro/antlr v0.0.1
|
||||||
github.com/zeromicro/ddl-parser v1.0.5
|
github.com/zeromicro/ddl-parser v1.0.5
|
||||||
github.com/zeromicro/go-zero v1.9.1
|
github.com/zeromicro/go-zero v1.9.2
|
||||||
golang.org/x/text v0.22.0
|
golang.org/x/text v0.22.0
|
||||||
google.golang.org/grpc v1.65.0
|
google.golang.org/grpc v1.65.0
|
||||||
google.golang.org/protobuf v1.36.5
|
google.golang.org/protobuf v1.36.5
|
||||||
|
|||||||
@@ -185,8 +185,8 @@ github.com/zeromicro/antlr v0.0.1 h1:CQpIn/dc0pUjgGQ81y98s/NGOm2Hfru2NNio2I9mQgk
|
|||||||
github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
|
github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
|
||||||
github.com/zeromicro/ddl-parser v1.0.5 h1:LaVqHdzMTjasua1yYpIYaksxKqRzFrEukj2Wi2EbWaQ=
|
github.com/zeromicro/ddl-parser v1.0.5 h1:LaVqHdzMTjasua1yYpIYaksxKqRzFrEukj2Wi2EbWaQ=
|
||||||
github.com/zeromicro/ddl-parser v1.0.5/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8=
|
github.com/zeromicro/ddl-parser v1.0.5/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8=
|
||||||
github.com/zeromicro/go-zero v1.9.1 h1:GZCl4jun/ZgZHnSvX3SSNDHf+tEGmEQ8x2Z23xjHa9g=
|
github.com/zeromicro/go-zero v1.9.2 h1:ZXOXBIcazZ1pWAMiHyVnDQ3Sxwy7DYPzjE89Qtj9vqM=
|
||||||
github.com/zeromicro/go-zero v1.9.1/go.mod h1:bHOl7Xr7EV/iHZWEqsUNJwFc/9WgAMrPpPagYvOaMtY=
|
github.com/zeromicro/go-zero v1.9.2/go.mod h1:k8YBMEFZKjTd4q/qO5RCW+zDgUlNyAs5vue3P4/Kmn0=
|
||||||
go.etcd.io/etcd/api/v3 v3.5.15 h1:3KpLJir1ZEBrYuV2v+Twaa/e2MdDCEZ/70H+lzEiwsk=
|
go.etcd.io/etcd/api/v3 v3.5.15 h1:3KpLJir1ZEBrYuV2v+Twaa/e2MdDCEZ/70H+lzEiwsk=
|
||||||
go.etcd.io/etcd/api/v3 v3.5.15/go.mod h1:N9EhGzXq58WuMllgH9ZvnEr7SI9pS0k0+DHZezGp7jM=
|
go.etcd.io/etcd/api/v3 v3.5.15/go.mod h1:N9EhGzXq58WuMllgH9ZvnEr7SI9pS0k0+DHZezGp7jM=
|
||||||
go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5fWlA=
|
go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5fWlA=
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// BuildVersion is the version of goctl.
|
// BuildVersion is the version of goctl.
|
||||||
const BuildVersion = "1.9.1"
|
const BuildVersion = "1.9.2"
|
||||||
|
|
||||||
var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-beta": 2, "beta": 3, "released": 4, "": 5}
|
var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-beta": 2, "beta": 3, "released": 4, "": 5}
|
||||||
|
|
||||||
|
|||||||
@@ -99,12 +99,12 @@ func (conn *MockConn) RawDB() (*sql.DB, error) {
|
|||||||
return conn.db, nil
|
return conn.db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transact is the implemention of sqlx.SqlConn, nothing to do
|
// Transact is the implementation of sqlx.SqlConn, nothing to do
|
||||||
func (conn *MockConn) Transact(func(session sqlx.Session) error) error {
|
func (conn *MockConn) Transact(func(session sqlx.Session) error) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TransactCtx is the implemention of sqlx.SqlConn, nothing to do
|
// TransactCtx is the implementation of sqlx.SqlConn, nothing to do
|
||||||
func (conn *MockConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
|
func (conn *MockConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package util
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/util/console"
|
"github.com/zeromicro/go-zero/tools/goctl/util/console"
|
||||||
@@ -130,14 +129,3 @@ func FieldsAndTrimSpace(s string, f func(r rune) bool) []string {
|
|||||||
}
|
}
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
//Deprecated: This function implementation is incomplete and does not properly handle exceptional input cases.
|
|
||||||
//We strongly recommend using the standard library's strconv.Unquote function instead,
|
|
||||||
//which provides robust error handling and comprehensive support for various input formats.
|
|
||||||
func Unquote(s string) string {
|
|
||||||
ns, err := strconv.Unquote(s)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return ns
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -120,20 +120,3 @@ func TestFieldsAndTrimSpace(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnquote(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
input string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{input: `"hello"`, expected: `hello`},
|
|
||||||
{input: "`world`", expected: `world`},
|
|
||||||
{input: `"foo'bar"`, expected: `foo'bar`},
|
|
||||||
{input: "", expected: ""},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
result := Unquote(tc.input)
|
|
||||||
assert.Equal(t, tc.expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
package zrpc
|
package zrpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/conf"
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/zrpc/internal"
|
"github.com/zeromicro/go-zero/zrpc/internal"
|
||||||
"github.com/zeromicro/go-zero/zrpc/internal/auth"
|
"github.com/zeromicro/go-zero/zrpc/internal/auth"
|
||||||
|
"github.com/zeromicro/go-zero/zrpc/internal/balancer/consistenthash"
|
||||||
|
"github.com/zeromicro/go-zero/zrpc/internal/balancer/p2c"
|
||||||
"github.com/zeromicro/go-zero/zrpc/internal/clientinterceptors"
|
"github.com/zeromicro/go-zero/zrpc/internal/clientinterceptors"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
@@ -17,6 +21,9 @@ var (
|
|||||||
WithDialOption = internal.WithDialOption
|
WithDialOption = internal.WithDialOption
|
||||||
// WithNonBlock sets the dialing to be nonblock.
|
// WithNonBlock sets the dialing to be nonblock.
|
||||||
WithNonBlock = internal.WithNonBlock
|
WithNonBlock = internal.WithNonBlock
|
||||||
|
// WithBlock sets the dialing to be blocking.
|
||||||
|
// Deprecated: blocking dials are not recommended by gRPC.
|
||||||
|
WithBlock = internal.WithBlock
|
||||||
// WithStreamClientInterceptor is an alias of internal.WithStreamClientInterceptor.
|
// WithStreamClientInterceptor is an alias of internal.WithStreamClientInterceptor.
|
||||||
WithStreamClientInterceptor = internal.WithStreamClientInterceptor
|
WithStreamClientInterceptor = internal.WithStreamClientInterceptor
|
||||||
// WithTimeout is an alias of internal.WithTimeout.
|
// WithTimeout is an alias of internal.WithTimeout.
|
||||||
@@ -57,6 +64,8 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
|
|||||||
}
|
}
|
||||||
if c.NonBlock {
|
if c.NonBlock {
|
||||||
opts = append(opts, WithNonBlock())
|
opts = append(opts, WithNonBlock())
|
||||||
|
} else {
|
||||||
|
opts = append(opts, WithBlock())
|
||||||
}
|
}
|
||||||
if c.Timeout > 0 {
|
if c.Timeout > 0 {
|
||||||
opts = append(opts, WithTimeout(time.Duration(c.Timeout)*time.Millisecond))
|
opts = append(opts, WithTimeout(time.Duration(c.Timeout)*time.Millisecond))
|
||||||
@@ -67,6 +76,9 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
|
|||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
svcCfg := makeLBServiceConfig(c.BalancerName)
|
||||||
|
opts = append(opts, WithDialOption(grpc.WithDefaultServiceConfig(svcCfg)))
|
||||||
|
|
||||||
opts = append(opts, options...)
|
opts = append(opts, options...)
|
||||||
|
|
||||||
target, err := c.BuildTarget()
|
target, err := c.BuildTarget()
|
||||||
@@ -111,7 +123,20 @@ func SetClientSlowThreshold(threshold time.Duration) {
|
|||||||
clientinterceptors.SetSlowThreshold(threshold)
|
clientinterceptors.SetSlowThreshold(threshold)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetHashKey sets the hash key into context.
|
||||||
|
func SetHashKey(ctx context.Context, key string) context.Context {
|
||||||
|
return consistenthash.SetHashKey(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
// WithCallTimeout return a call option with given timeout to make a method call.
|
// WithCallTimeout return a call option with given timeout to make a method call.
|
||||||
func WithCallTimeout(timeout time.Duration) grpc.CallOption {
|
func WithCallTimeout(timeout time.Duration) grpc.CallOption {
|
||||||
return clientinterceptors.WithCallTimeout(timeout)
|
return clientinterceptors.WithCallTimeout(timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeLBServiceConfig(balancerName string) string {
|
||||||
|
if len(balancerName) == 0 {
|
||||||
|
balancerName = p2c.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, balancerName)
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/discov"
|
"github.com/zeromicro/go-zero/core/discov"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/internal/mock"
|
"github.com/zeromicro/go-zero/internal/mock"
|
||||||
|
"github.com/zeromicro/go-zero/zrpc/internal/balancer/consistenthash"
|
||||||
|
"github.com/zeromicro/go-zero/zrpc/internal/balancer/p2c"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
@@ -245,3 +247,42 @@ func TestNewClientWithTarget(t *testing.T) {
|
|||||||
|
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMakeLBServiceConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty name uses default p2c",
|
||||||
|
input: "",
|
||||||
|
expected: fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, p2c.Name),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom balancer name",
|
||||||
|
input: "consistent_hash",
|
||||||
|
expected: `{"loadBalancingPolicy":"consistent_hash"}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := makeLBServiceConfig(tt.input)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("expected %q, got %q", tt.expected, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetHashKey(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
key := "abc123"
|
||||||
|
|
||||||
|
ctx = SetHashKey(ctx, key)
|
||||||
|
got := consistenthash.GetHashKey(ctx)
|
||||||
|
assert.Equal(t, key, got)
|
||||||
|
|
||||||
|
assert.Empty(t, consistenthash.GetHashKey(context.Background()))
|
||||||
|
}
|
||||||
|
|||||||
@@ -27,10 +27,11 @@ type (
|
|||||||
Target string `json:",optional"`
|
Target string `json:",optional"`
|
||||||
App string `json:",optional"`
|
App string `json:",optional"`
|
||||||
Token string `json:",optional"`
|
Token string `json:",optional"`
|
||||||
NonBlock bool `json:",optional"`
|
NonBlock bool `json:",default=true"`
|
||||||
Timeout int64 `json:",default=2000"`
|
Timeout int64 `json:",default=2000"`
|
||||||
KeepaliveTime time.Duration `json:",optional"`
|
KeepaliveTime time.Duration `json:",optional"`
|
||||||
Middlewares ClientMiddlewaresConf
|
Middlewares ClientMiddlewaresConf
|
||||||
|
BalancerName string `json:",default=p2c_ewma"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// A RpcServerConf is a rpc server config.
|
// A RpcServerConf is a rpc server config.
|
||||||
|
|||||||
@@ -4,9 +4,11 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
zconf "github.com/zeromicro/go-zero/core/conf"
|
||||||
"github.com/zeromicro/go-zero/core/discov"
|
"github.com/zeromicro/go-zero/core/discov"
|
||||||
"github.com/zeromicro/go-zero/core/service"
|
"github.com/zeromicro/go-zero/core/service"
|
||||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||||
|
"github.com/zeromicro/go-zero/zrpc/internal/balancer/p2c"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRpcClientConf(t *testing.T) {
|
func TestRpcClientConf(t *testing.T) {
|
||||||
@@ -39,6 +41,13 @@ func TestRpcClientConf(t *testing.T) {
|
|||||||
_, err := conf.BuildTarget()
|
_, err := conf.BuildTarget()
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("default balancer name", func(t *testing.T) {
|
||||||
|
var conf RpcClientConf
|
||||||
|
err := zconf.FillDefault(&conf)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, p2c.Name, conf.BalancerName)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRpcServerConf(t *testing.T) {
|
func TestRpcServerConf(t *testing.T) {
|
||||||
|
|||||||
97
zrpc/internal/balancer/consistenthash/consistenthash.go
Normal file
97
zrpc/internal/balancer/consistenthash/consistenthash.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package consistenthash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/hash"
|
||||||
|
"google.golang.org/grpc/balancer"
|
||||||
|
"google.golang.org/grpc/balancer/base"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
Name = "consistent_hash"
|
||||||
|
|
||||||
|
defaultReplicaCount = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
var emptyPickResult balancer.PickResult
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
balancer.Register(newBuilder())
|
||||||
|
}
|
||||||
|
|
||||||
|
type (
|
||||||
|
// hashKey is the key type for consistent hash in context.
|
||||||
|
hashKey struct{}
|
||||||
|
// pickerBuilder is a builder for picker.
|
||||||
|
pickerBuilder struct{}
|
||||||
|
// picker is a picker that uses consistent hash to pick a sub connection.
|
||||||
|
picker struct {
|
||||||
|
hashRing *hash.ConsistentHash
|
||||||
|
conns map[string]balancer.SubConn
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (b *pickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker {
|
||||||
|
readySCs := info.ReadySCs
|
||||||
|
if len(readySCs) == 0 {
|
||||||
|
return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
|
||||||
|
}
|
||||||
|
|
||||||
|
conns := make(map[string]balancer.SubConn, len(readySCs))
|
||||||
|
hashRing := hash.NewCustomConsistentHash(defaultReplicaCount, hash.Hash)
|
||||||
|
for conn, connInfo := range readySCs {
|
||||||
|
addr := connInfo.Address.Addr
|
||||||
|
conns[addr] = conn
|
||||||
|
hashRing.Add(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &picker{
|
||||||
|
hashRing: hashRing,
|
||||||
|
conns: conns,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBuilder() balancer.Builder {
|
||||||
|
return base.NewBalancerBuilder(Name, &pickerBuilder{}, base.Config{HealthCheck: true})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
|
||||||
|
hashKey := GetHashKey(info.Ctx)
|
||||||
|
if len(hashKey) == 0 {
|
||||||
|
return emptyPickResult, status.Error(codes.InvalidArgument,
|
||||||
|
"[consistent_hash] missing hash key in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
if addrAny, ok := p.hashRing.Get(hashKey); ok {
|
||||||
|
addr, ok := addrAny.(string)
|
||||||
|
if !ok {
|
||||||
|
return emptyPickResult, status.Error(codes.Internal,
|
||||||
|
"[consistent_hash] invalid addr type in consistent hash")
|
||||||
|
}
|
||||||
|
|
||||||
|
subConn, ok := p.conns[addr]
|
||||||
|
if !ok {
|
||||||
|
return emptyPickResult, status.Errorf(codes.Internal,
|
||||||
|
"[consistent_hash] no subConn for addr: %s", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return balancer.PickResult{SubConn: subConn}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return emptyPickResult, status.Errorf(codes.Unavailable,
|
||||||
|
"[consistent_hash] no matching conn for hashKey: %s", hashKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHashKey sets the hash key into context.
|
||||||
|
func SetHashKey(ctx context.Context, key string) context.Context {
|
||||||
|
return context.WithValue(ctx, hashKey{}, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHashKey gets the hash key from context.
|
||||||
|
func GetHashKey(ctx context.Context) string {
|
||||||
|
v, _ := ctx.Value(hashKey{}).(string)
|
||||||
|
return v
|
||||||
|
}
|
||||||
175
zrpc/internal/balancer/consistenthash/consistenthash_test.go
Normal file
175
zrpc/internal/balancer/consistenthash/consistenthash_test.go
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
package consistenthash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/hash"
|
||||||
|
"google.golang.org/grpc/balancer"
|
||||||
|
"google.golang.org/grpc/balancer/base"
|
||||||
|
"google.golang.org/grpc/resolver"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeSubConn struct{ id int }
|
||||||
|
|
||||||
|
func (f *fakeSubConn) Connect() {}
|
||||||
|
func (f *fakeSubConn) UpdateAddresses(_ []resolver.Address) {}
|
||||||
|
func (f *fakeSubConn) Shutdown() {}
|
||||||
|
func (f *fakeSubConn) GetOrBuildProducer(b balancer.ProducerBuilder) (balancer.Producer, func()) {
|
||||||
|
return nil, func() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPickerBuilder_EmptyReadySCs(t *testing.T) {
|
||||||
|
b := &pickerBuilder{}
|
||||||
|
p := b.Build(base.PickerBuildInfo{ReadySCs: map[balancer.SubConn]base.SubConnInfo{}})
|
||||||
|
|
||||||
|
_, err := p.Pick(balancer.PickInfo{})
|
||||||
|
assert.Equal(t, balancer.ErrNoSubConnAvailable, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPickerBuilder_BuildAndRing(t *testing.T) {
|
||||||
|
subConn1 := &fakeSubConn{id: 1}
|
||||||
|
subConn2 := &fakeSubConn{id: 2}
|
||||||
|
addr1 := "127.0.0.1:8080"
|
||||||
|
addr2 := "127.0.0.1:8081"
|
||||||
|
|
||||||
|
b := &pickerBuilder{}
|
||||||
|
info := base.PickerBuildInfo{
|
||||||
|
ReadySCs: map[balancer.SubConn]base.SubConnInfo{
|
||||||
|
subConn1: {Address: resolver.Address{Addr: addr1}},
|
||||||
|
subConn2: {Address: resolver.Address{Addr: addr2}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
p := b.Build(info).(*picker)
|
||||||
|
assert.NotNil(t, p.hashRing)
|
||||||
|
assert.Len(t, p.conns, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPicker_HashConsistency(t *testing.T) {
|
||||||
|
subConn1 := &fakeSubConn{id: 1}
|
||||||
|
subConn2 := &fakeSubConn{id: 2}
|
||||||
|
|
||||||
|
pb := &pickerBuilder{}
|
||||||
|
info := base.PickerBuildInfo{
|
||||||
|
ReadySCs: map[balancer.SubConn]base.SubConnInfo{
|
||||||
|
subConn1: {Address: resolver.Address{Addr: "127.0.0.1:8080"}},
|
||||||
|
subConn2: {Address: resolver.Address{Addr: "127.0.0.1:8081"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p := pb.Build(info).(*picker)
|
||||||
|
ctx := SetHashKey(context.Background(), "user_123")
|
||||||
|
res1, err := p.Pick(balancer.PickInfo{Ctx: ctx})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, res1.SubConn)
|
||||||
|
|
||||||
|
// Multiple requests with the same key remain consistent
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
resN, err := p.Pick(balancer.PickInfo{Ctx: ctx})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, res1.SubConn, resN.SubConn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPicker_MissingKey(t *testing.T) {
|
||||||
|
subConn := &fakeSubConn{id: 1}
|
||||||
|
|
||||||
|
pb := &pickerBuilder{}
|
||||||
|
info := base.PickerBuildInfo{
|
||||||
|
ReadySCs: map[balancer.SubConn]base.SubConnInfo{
|
||||||
|
subConn: {Address: resolver.Address{Addr: "127.0.0.1:8080"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p := pb.Build(info).(*picker)
|
||||||
|
|
||||||
|
// No hash key in context
|
||||||
|
_, err := p.Pick(balancer.PickInfo{Ctx: context.Background()})
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "[consistent_hash] missing hash key in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPicker_NoMatchingConn(t *testing.T) {
|
||||||
|
emptyRing := newCustomRingForTest()
|
||||||
|
p := &picker{
|
||||||
|
hashRing: emptyRing,
|
||||||
|
conns: map[string]balancer.SubConn{},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := p.Pick(balancer.PickInfo{Ctx: SetHashKey(context.Background(), "someone")})
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "[consistent_hash] no matching conn for hashKey: someone")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPicker_InvalidAddrType(t *testing.T) {
|
||||||
|
ring := newCustomRingForTest()
|
||||||
|
ring.Add(12345)
|
||||||
|
|
||||||
|
subConn := &fakeSubConn{id: 1}
|
||||||
|
p := &picker{
|
||||||
|
hashRing: ring,
|
||||||
|
conns: map[string]balancer.SubConn{
|
||||||
|
"12345": subConn,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := p.Pick(balancer.PickInfo{Ctx: SetHashKey(context.Background(), "anykey")})
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "[consistent_hash] invalid addr type in consistent hash")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPicker_NoSubConnForAddr(t *testing.T) {
|
||||||
|
ring := newCustomRingForTest()
|
||||||
|
ring.Add("ghost:9999")
|
||||||
|
|
||||||
|
exist := &fakeSubConn{id: 1}
|
||||||
|
p := &picker{
|
||||||
|
hashRing: ring,
|
||||||
|
conns: map[string]balancer.SubConn{
|
||||||
|
"real:8080": exist,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := p.Pick(balancer.PickInfo{Ctx: SetHashKey(context.Background(), "anykey")})
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "[consistent_hash] no subConn for addr: ghost:9999")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetAndGetHashKey(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
key := "abc123"
|
||||||
|
|
||||||
|
ctx = SetHashKey(ctx, key)
|
||||||
|
got := GetHashKey(ctx)
|
||||||
|
assert.Equal(t, key, got)
|
||||||
|
|
||||||
|
assert.Empty(t, GetHashKey(context.Background()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPicker_HashConsistency(b *testing.B) {
|
||||||
|
subConn1 := &fakeSubConn{id: 1}
|
||||||
|
subConn2 := &fakeSubConn{id: 2}
|
||||||
|
|
||||||
|
pb := &pickerBuilder{}
|
||||||
|
info := base.PickerBuildInfo{
|
||||||
|
ReadySCs: map[balancer.SubConn]base.SubConnInfo{
|
||||||
|
subConn1: {Address: resolver.Address{Addr: "127.0.0.1:8080"}},
|
||||||
|
subConn2: {Address: resolver.Address{Addr: "127.0.0.1:8081"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p := pb.Build(info).(*picker)
|
||||||
|
|
||||||
|
ctx := SetHashKey(context.Background(), "hot_user_123")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
res, err := p.Pick(balancer.PickInfo{Ctx: ctx})
|
||||||
|
if err != nil || res.SubConn == nil {
|
||||||
|
b.Fatalf("unexpected result: res=%v err=%v", res.SubConn, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCustomRingForTest() *hash.ConsistentHash {
|
||||||
|
return hash.NewCustomConsistentHash(defaultReplicaCount, hash.Hash)
|
||||||
|
}
|
||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/zrpc/internal/balancer/p2c"
|
|
||||||
"github.com/zeromicro/go-zero/zrpc/internal/clientinterceptors"
|
"github.com/zeromicro/go-zero/zrpc/internal/clientinterceptors"
|
||||||
"github.com/zeromicro/go-zero/zrpc/resolver"
|
"github.com/zeromicro/go-zero/zrpc/resolver"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@@ -53,9 +52,6 @@ func NewClient(target string, middlewares ClientMiddlewaresConf, opts ...ClientO
|
|||||||
middlewares: middlewares,
|
middlewares: middlewares,
|
||||||
}
|
}
|
||||||
|
|
||||||
svcCfg := fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, p2c.Name)
|
|
||||||
balancerOpt := WithDialOption(grpc.WithDefaultServiceConfig(svcCfg))
|
|
||||||
opts = append([]ClientOption{balancerOpt}, opts...)
|
|
||||||
if err := cli.dial(target, opts...); err != nil {
|
if err := cli.dial(target, opts...); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -145,6 +141,15 @@ func (c *client) dial(server string, opts ...ClientOption) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithBlock sets the dialing to be blocking.
|
||||||
|
// Deprecated: blocking dials are not recommended by gRPC.
|
||||||
|
// See https://github.com/grpc/grpc-go/blob/master/Documentation/anti-patterns.md
|
||||||
|
func WithBlock() ClientOption {
|
||||||
|
return func(options *ClientOptions) {
|
||||||
|
options.NonBlock = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithDialOption returns a func to customize a ClientOptions with given dial option.
|
// WithDialOption returns a func to customize a ClientOptions with given dial option.
|
||||||
func WithDialOption(opt grpc.DialOption) ClientOption {
|
func WithDialOption(opt grpc.DialOption) ClientOption {
|
||||||
return func(options *ClientOptions) {
|
return func(options *ClientOptions) {
|
||||||
|
|||||||
@@ -34,6 +34,13 @@ func TestWithNonBlock(t *testing.T) {
|
|||||||
assert.True(t, options.NonBlock)
|
assert.True(t, options.NonBlock)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithBlock(t *testing.T) {
|
||||||
|
var options ClientOptions
|
||||||
|
opt := WithBlock()
|
||||||
|
opt(&options)
|
||||||
|
assert.False(t, options.NonBlock)
|
||||||
|
}
|
||||||
|
|
||||||
func TestWithStreamClientInterceptor(t *testing.T) {
|
func TestWithStreamClientInterceptor(t *testing.T) {
|
||||||
var options ClientOptions
|
var options ClientOptions
|
||||||
opt := WithStreamClientInterceptor(func(ctx context.Context, desc *grpc.StreamDesc,
|
opt := WithStreamClientInterceptor(func(ctx context.Context, desc *grpc.StreamDesc,
|
||||||
|
|||||||
Reference in New Issue
Block a user