mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-06-29 23:41:03 +08:00
Compare commits
4 Commits
b2e3aa1587
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f910257ec9 | ||
|
|
d318de1212 | ||
|
|
99515480cf | ||
|
|
dbc71bb57b |
@@ -931,6 +931,113 @@ func TestUnmarshalJsonArray(t *testing.T) {
|
||||
assert.Equal(t, 18, v[0].Age)
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesPointerSliceUint64(t *testing.T) {
|
||||
t.Run("with values", func(t *testing.T) {
|
||||
var c struct {
|
||||
IDs *[]uint64 `json:"ids,optional"`
|
||||
}
|
||||
content := []byte(`{"ids":[9000,9001]}`)
|
||||
|
||||
assert.Nil(t, UnmarshalJsonBytes(content, &c))
|
||||
assert.NotNil(t, c.IDs)
|
||||
assert.Equal(t, []uint64{9000, 9001}, *c.IDs)
|
||||
})
|
||||
|
||||
t.Run("omitted", func(t *testing.T) {
|
||||
var c struct {
|
||||
IDs *[]uint64 `json:"ids,optional"`
|
||||
}
|
||||
content := []byte(`{}`)
|
||||
|
||||
assert.Nil(t, UnmarshalJsonBytes(content, &c))
|
||||
assert.Nil(t, c.IDs)
|
||||
})
|
||||
|
||||
t.Run("null", func(t *testing.T) {
|
||||
var c struct {
|
||||
IDs *[]uint64 `json:"ids,optional"`
|
||||
}
|
||||
content := []byte(`{"ids":null}`)
|
||||
|
||||
assert.Nil(t, UnmarshalJsonBytes(content, &c))
|
||||
assert.Nil(t, c.IDs)
|
||||
})
|
||||
|
||||
t.Run("empty array", func(t *testing.T) {
|
||||
var c struct {
|
||||
IDs *[]uint64 `json:"ids,optional"`
|
||||
}
|
||||
content := []byte(`{"ids":[]}`)
|
||||
|
||||
assert.Nil(t, UnmarshalJsonBytes(content, &c))
|
||||
assert.NotNil(t, c.IDs)
|
||||
assert.Equal(t, []uint64{}, *c.IDs)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesPointerSliceOtherTypes(t *testing.T) {
|
||||
t.Run("pointer to []string", func(t *testing.T) {
|
||||
var c struct {
|
||||
Names *[]string `json:"names,optional"`
|
||||
}
|
||||
content := []byte(`{"names":["a","b"]}`)
|
||||
|
||||
assert.Nil(t, UnmarshalJsonBytes(content, &c))
|
||||
assert.NotNil(t, c.Names)
|
||||
assert.Equal(t, []string{"a", "b"}, *c.Names)
|
||||
})
|
||||
|
||||
t.Run("pointer to []int", func(t *testing.T) {
|
||||
var c struct {
|
||||
Values *[]int `json:"values,optional"`
|
||||
}
|
||||
content := []byte(`{"values":[1,2,3]}`)
|
||||
|
||||
assert.Nil(t, UnmarshalJsonBytes(content, &c))
|
||||
assert.NotNil(t, c.Values)
|
||||
assert.Equal(t, []int{1, 2, 3}, *c.Values)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesPointerSliceStruct(t *testing.T) {
|
||||
type Item struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
t.Run("with values", func(t *testing.T) {
|
||||
var c struct {
|
||||
Items *[]Item `json:"items,optional"`
|
||||
}
|
||||
content := []byte(`{"items":[{"name":"alice","age":30},{"name":"bob","age":25}]}`)
|
||||
|
||||
assert.Nil(t, UnmarshalJsonBytes(content, &c))
|
||||
assert.NotNil(t, c.Items)
|
||||
assert.Equal(t, []Item{{Name: "alice", Age: 30}, {Name: "bob", Age: 25}}, *c.Items)
|
||||
})
|
||||
|
||||
t.Run("omitted", func(t *testing.T) {
|
||||
var c struct {
|
||||
Items *[]Item `json:"items,optional"`
|
||||
}
|
||||
content := []byte(`{}`)
|
||||
|
||||
assert.Nil(t, UnmarshalJsonBytes(content, &c))
|
||||
assert.Nil(t, c.Items)
|
||||
})
|
||||
|
||||
t.Run("empty array", func(t *testing.T) {
|
||||
var c struct {
|
||||
Items *[]Item `json:"items,optional"`
|
||||
}
|
||||
content := []byte(`{"items":[]}`)
|
||||
|
||||
assert.Nil(t, UnmarshalJsonBytes(content, &c))
|
||||
assert.NotNil(t, c.Items)
|
||||
assert.Equal(t, []Item{}, *c.Items)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesError(t *testing.T) {
|
||||
var v []struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -142,11 +142,11 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,
|
||||
return nil
|
||||
}
|
||||
|
||||
baseType := fieldType.Elem()
|
||||
baseType := Deref(fieldType).Elem()
|
||||
dereffedBaseType := Deref(baseType)
|
||||
dereffedBaseKind := dereffedBaseType.Kind()
|
||||
if refValue.Len() == 0 {
|
||||
value.Set(reflect.MakeSlice(reflect.SliceOf(baseType), 0, 0))
|
||||
SetValue(fieldType, value, reflect.MakeSlice(reflect.SliceOf(baseType), 0, 0))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -179,7 +179,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,
|
||||
}
|
||||
|
||||
if valid {
|
||||
value.Set(conv)
|
||||
SetValue(fieldType, value, conv)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -201,7 +201,7 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
|
||||
return errUnsupportedType
|
||||
}
|
||||
|
||||
baseFieldType := fieldType.Elem()
|
||||
baseFieldType := Deref(fieldType).Elem()
|
||||
baseFieldKind := baseFieldType.Kind()
|
||||
conv := reflect.MakeSlice(reflect.SliceOf(baseFieldType), len(slice), cap(slice))
|
||||
|
||||
@@ -211,7 +211,7 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
|
||||
}
|
||||
}
|
||||
|
||||
value.Set(conv)
|
||||
SetValue(fieldType, value, conv)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
var ignoreCmds = map[string]lang.PlaceholderType{
|
||||
"blpop": {},
|
||||
"hello": {},
|
||||
}
|
||||
|
||||
type breakerHook struct {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
red "github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
@@ -75,6 +76,45 @@ func TestBreakerHook_ProcessHook(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, someError.Error(), err.Error())
|
||||
})
|
||||
|
||||
t.Run("breakerHook_ignoreHello", func(t *testing.T) {
|
||||
// hello is issued on connection init and is in ignoreCmds, so repeated
|
||||
// failures must never trip the breaker into ErrServiceUnavailable.
|
||||
h := breakerHook{brk: breaker.NewBreaker()}
|
||||
someError := errors.New("ERR some error")
|
||||
process := h.ProcessHook(func(_ context.Context, _ red.Cmder) error {
|
||||
return someError
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
var err error
|
||||
for i := 0; i < 1000; i++ {
|
||||
err = process(ctx, red.NewCmd(ctx, "hello", 3))
|
||||
if err != nil && err.Error() != someError.Error() {
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Equal(t, someError.Error(), err.Error())
|
||||
})
|
||||
|
||||
t.Run("breakerHook_notIgnored", func(t *testing.T) {
|
||||
// a regular command is not ignored, so repeated failures open the breaker.
|
||||
h := breakerHook{brk: breaker.NewBreaker()}
|
||||
someError := errors.New("ERR some error")
|
||||
process := h.ProcessHook(func(_ context.Context, _ red.Cmder) error {
|
||||
return someError
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
var err error
|
||||
for i := 0; i < 1000; i++ {
|
||||
err = process(ctx, red.NewCmd(ctx, "get", "key"))
|
||||
if err != nil && err.Error() != someError.Error() {
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBreakerHook_ProcessPipelineHook(t *testing.T) {
|
||||
|
||||
@@ -23,6 +23,30 @@ type (
|
||||
Pass string `json:",optional"`
|
||||
Tls bool `json:",optional"`
|
||||
NonBlock bool `json:",default=true"`
|
||||
// DisableIdentity is used to disable CLIENT SETINFO command on connect.
|
||||
//
|
||||
// Some redis versions/proxies do not support CLIENT SETINFO and return an
|
||||
// error on connect; since that command runs through the breaker hook it can
|
||||
// trip the breaker. Set this to true to skip it on such servers. Together
|
||||
// with the default MaintNotifications=disabled (and the always-ignored
|
||||
// HELLO command), this keeps the connect-time commands from tripping the
|
||||
// breaker on incompatible servers, without forcing RESP2.
|
||||
//
|
||||
// default: false
|
||||
DisableIdentity bool `json:",default=false"`
|
||||
// Protocol 2 or 3. Use the version to negotiate RESP version with redis-server.
|
||||
//
|
||||
// default: 3.
|
||||
Protocol int `json:",default=3"`
|
||||
// MaintNotifications controls the CLIENT MAINT_NOTIFICATIONS handshake mode
|
||||
// (go-redis MaintNotificationsConfig.Mode):
|
||||
// - disabled: never send the command (avoids tripping the breaker on servers
|
||||
// that don't support it; keeps RESP3 intact)
|
||||
// - auto: try, silently fall back on error (go-redis default)
|
||||
// - enabled: force, fail the connection on error
|
||||
//
|
||||
// default: disabled
|
||||
MaintNotifications string `json:",default=disabled,options=disabled|enabled|auto"`
|
||||
// PingTimeout is the timeout for ping redis.
|
||||
PingTimeout time.Duration `json:",default=1s"`
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
red "github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
@@ -53,13 +54,16 @@ type (
|
||||
|
||||
// Redis defines a redis node/cluster. It is thread-safe.
|
||||
Redis struct {
|
||||
Addr string
|
||||
Type string
|
||||
User string
|
||||
Pass string
|
||||
tls bool
|
||||
brk breaker.Breaker
|
||||
hooks []red.Hook
|
||||
Addr string
|
||||
Type string
|
||||
User string
|
||||
Pass string
|
||||
protocol int
|
||||
identity bool
|
||||
maintNotifications maintnotifications.Mode
|
||||
tls bool
|
||||
brk breaker.Breaker
|
||||
hooks []red.Hook
|
||||
}
|
||||
|
||||
// RedisNode interface represents a redis node.
|
||||
@@ -136,6 +140,15 @@ func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) {
|
||||
if conf.Tls {
|
||||
opts = append([]Option{WithTLS()}, opts...)
|
||||
}
|
||||
if conf.Protocol > 0 {
|
||||
opts = append([]Option{WithProtocol(conf.Protocol)}, opts...)
|
||||
}
|
||||
if conf.DisableIdentity {
|
||||
opts = append([]Option{WithIdentity()}, opts...)
|
||||
}
|
||||
if len(conf.MaintNotifications) > 0 {
|
||||
opts = append([]Option{WithMaintNotifications(conf.MaintNotifications)}, opts...)
|
||||
}
|
||||
|
||||
rds := newRedis(conf.Host, opts...)
|
||||
if !conf.NonBlock {
|
||||
@@ -147,20 +160,6 @@ func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) {
|
||||
return rds, nil
|
||||
}
|
||||
|
||||
func newRedis(addr string, opts ...Option) *Redis {
|
||||
r := &Redis{
|
||||
Addr: addr,
|
||||
Type: NodeType,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// NewScript returns a new Script instance.
|
||||
func NewScript(script string) *Script {
|
||||
return red.NewScript(script)
|
||||
@@ -2686,6 +2685,18 @@ func (s *Redis) checkConnection(pingTimeout time.Duration) error {
|
||||
return conn.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
// maintNotificationsConfig builds the go-redis maintenance notifications config
|
||||
// from the configured mode, defaulting to disabled when unset so that the
|
||||
// CLIENT MAINT_NOTIFICATIONS command is not issued on connect.
|
||||
func (r *Redis) maintNotificationsConfig() *maintnotifications.Config {
|
||||
mode := r.maintNotifications
|
||||
if len(mode) == 0 {
|
||||
mode = maintnotifications.ModeDisabled
|
||||
}
|
||||
|
||||
return &maintnotifications.Config{Mode: mode}
|
||||
}
|
||||
|
||||
// Cluster customizes the given Redis as a cluster.
|
||||
func Cluster() Option {
|
||||
return func(r *Redis) {
|
||||
@@ -2726,6 +2737,28 @@ func WithUser(user string) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithProtocol customizes the given Redis with protocol.
|
||||
func WithProtocol(protocol int) Option {
|
||||
return func(r *Redis) {
|
||||
r.protocol = protocol
|
||||
}
|
||||
}
|
||||
|
||||
// WithIdentity customizes the given Redis with Identity enabled.
|
||||
func WithIdentity() Option {
|
||||
return func(r *Redis) {
|
||||
r.identity = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaintNotifications customizes the given Redis with the maintenance
|
||||
// notifications mode (disabled, enabled or auto).
|
||||
func WithMaintNotifications(mode string) Option {
|
||||
return func(r *Redis) {
|
||||
r.maintNotifications = maintnotifications.Mode(mode)
|
||||
}
|
||||
}
|
||||
|
||||
func acceptable(err error) bool {
|
||||
return err == nil || errorx.In(err, red.Nil, context.Canceled)
|
||||
}
|
||||
@@ -2741,6 +2774,20 @@ func getRedis(r *Redis) (RedisNode, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func newRedis(addr string, opts ...Option) *Redis {
|
||||
r := &Redis{
|
||||
Addr: addr,
|
||||
Type: NodeType,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func toPairs(vals []red.Z) []Pair {
|
||||
pairs := make([]Pair, len(vals))
|
||||
for i, val := range vals {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
red "github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
@@ -150,6 +151,82 @@ func TestNewRedis(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientWithProtocolAndIdentity(t *testing.T) {
|
||||
r := miniredis.RunT(t)
|
||||
defer r.Close()
|
||||
c, err := getClient(&Redis{
|
||||
Addr: r.Addr(),
|
||||
Type: NodeType,
|
||||
protocol: 2,
|
||||
identity: true,
|
||||
})
|
||||
if assert.NoError(t, err) {
|
||||
assert.NotNil(t, c)
|
||||
assert.Equal(t, 2, c.Options().Protocol)
|
||||
assert.True(t, c.Options().DisableIdentity)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRedis_ProtocolAndIdentity(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
s := miniredis.RunT(t)
|
||||
rds, err := NewRedis(RedisConf{
|
||||
Host: s.Addr(),
|
||||
Type: NodeType,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
})
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, 2, rds.protocol)
|
||||
assert.True(t, rds.identity)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientWithMaintNotifications(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode maintnotifications.Mode
|
||||
want maintnotifications.Mode
|
||||
}{
|
||||
{name: "unset falls back to disabled", mode: "", want: maintnotifications.ModeDisabled},
|
||||
{name: "disabled", mode: maintnotifications.ModeDisabled, want: maintnotifications.ModeDisabled},
|
||||
{name: "enabled", mode: maintnotifications.ModeEnabled, want: maintnotifications.ModeEnabled},
|
||||
{name: "auto", mode: maintnotifications.ModeAuto, want: maintnotifications.ModeAuto},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
r := miniredis.RunT(t)
|
||||
defer r.Close()
|
||||
c, err := getClient(&Redis{
|
||||
Addr: r.Addr(),
|
||||
Type: NodeType,
|
||||
maintNotifications: test.mode,
|
||||
})
|
||||
if assert.NoError(t, err) {
|
||||
assert.NotNil(t, c)
|
||||
assert.NotNil(t, c.Options().MaintNotificationsConfig)
|
||||
assert.Equal(t, test.want, c.Options().MaintNotificationsConfig.Mode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRedis_MaintNotifications(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
s := miniredis.RunT(t)
|
||||
rds, err := NewRedis(RedisConf{
|
||||
Host: s.Addr(),
|
||||
Type: NodeType,
|
||||
MaintNotifications: string(maintnotifications.ModeAuto),
|
||||
})
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, maintnotifications.ModeAuto, rds.maintNotifications)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedis_NonBlock(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
|
||||
@@ -50,25 +50,31 @@ func CreateBlockingNode(r *Redis) (ClosableNode, error) {
|
||||
switch r.Type {
|
||||
case NodeType:
|
||||
client := red.NewClient(&red.Options{
|
||||
Addr: r.Addr,
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
DB: defaultDatabase,
|
||||
MaxRetries: maxRetries,
|
||||
PoolSize: 1,
|
||||
MinIdleConns: 1,
|
||||
ReadTimeout: timeout,
|
||||
Addr: r.Addr,
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
DB: defaultDatabase,
|
||||
MaxRetries: maxRetries,
|
||||
PoolSize: 1,
|
||||
MinIdleConns: 1,
|
||||
ReadTimeout: timeout,
|
||||
Protocol: r.protocol,
|
||||
DisableIdentity: r.identity,
|
||||
MaintNotificationsConfig: r.maintNotificationsConfig(),
|
||||
})
|
||||
return &clientBridge{client}, nil
|
||||
case ClusterType:
|
||||
client := red.NewClusterClient(&red.ClusterOptions{
|
||||
Addrs: splitClusterAddrs(r.Addr),
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
MaxRetries: maxRetries,
|
||||
PoolSize: 1,
|
||||
MinIdleConns: 1,
|
||||
ReadTimeout: timeout,
|
||||
Addrs: splitClusterAddrs(r.Addr),
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
MaxRetries: maxRetries,
|
||||
PoolSize: 1,
|
||||
MinIdleConns: 1,
|
||||
ReadTimeout: timeout,
|
||||
Protocol: r.protocol,
|
||||
DisableIdentity: r.identity,
|
||||
MaintNotificationsConfig: r.maintNotificationsConfig(),
|
||||
})
|
||||
return &clusterBridge{client}, nil
|
||||
default:
|
||||
|
||||
@@ -43,4 +43,32 @@ func TestBlockingNode(t *testing.T) {
|
||||
_, err = CreateBlockingNode(New(r.Addr(), badType()))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("test blocking node with protocol and identity", func(t *testing.T) {
|
||||
r, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r.Close()
|
||||
|
||||
node, err := CreateBlockingNode(New(r.Addr(), WithProtocol(2), WithIdentity()))
|
||||
assert.NoError(t, err)
|
||||
bridge, ok := node.(*clientBridge)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 2, bridge.Options().Protocol)
|
||||
assert.True(t, bridge.Options().DisableIdentity)
|
||||
node.Close()
|
||||
})
|
||||
|
||||
t.Run("test blocking node with cluster, protocol and identity", func(t *testing.T) {
|
||||
r, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r.Close()
|
||||
|
||||
node, err := CreateBlockingNode(New(r.Addr(), Cluster(), WithProtocol(2), WithIdentity()))
|
||||
assert.NoError(t, err)
|
||||
bridge, ok := node.(*clusterBridge)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 2, bridge.Options().Protocol)
|
||||
assert.True(t, bridge.Options().DisableIdentity)
|
||||
node.Close()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -30,13 +30,16 @@ func getClient(r *Redis) (*red.Client, error) {
|
||||
}
|
||||
}
|
||||
store := red.NewClient(&red.Options{
|
||||
Addr: r.Addr,
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
DB: defaultDatabase,
|
||||
MaxRetries: maxRetries,
|
||||
MinIdleConns: idleConns,
|
||||
TLSConfig: tlsConfig,
|
||||
Addr: r.Addr,
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
DB: defaultDatabase,
|
||||
MaxRetries: maxRetries,
|
||||
MinIdleConns: idleConns,
|
||||
TLSConfig: tlsConfig,
|
||||
Protocol: r.protocol,
|
||||
DisableIdentity: r.identity,
|
||||
MaintNotificationsConfig: r.maintNotificationsConfig(),
|
||||
})
|
||||
|
||||
hooks := append([]red.Hook{defaultDurationHook, breakerHook{
|
||||
|
||||
@@ -27,12 +27,15 @@ func getCluster(r *Redis) (*red.ClusterClient, error) {
|
||||
}
|
||||
}
|
||||
store := red.NewClusterClient(&red.ClusterOptions{
|
||||
Addrs: splitClusterAddrs(r.Addr),
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
MaxRetries: maxRetries,
|
||||
MinIdleConns: idleConns,
|
||||
TLSConfig: tlsConfig,
|
||||
Addrs: splitClusterAddrs(r.Addr),
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
MaxRetries: maxRetries,
|
||||
MinIdleConns: idleConns,
|
||||
TLSConfig: tlsConfig,
|
||||
Protocol: r.protocol,
|
||||
DisableIdentity: r.identity,
|
||||
MaintNotificationsConfig: r.maintNotificationsConfig(),
|
||||
})
|
||||
|
||||
hooks := append([]red.Hook{defaultDurationHook, breakerHook{
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
red "github.com/redis/go-redis/v9"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -57,3 +58,50 @@ func TestGetCluster(t *testing.T) {
|
||||
assert.NotNil(t, c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClusterWithProtocolAndIdentity(t *testing.T) {
|
||||
r := miniredis.RunT(t)
|
||||
defer r.Close()
|
||||
c, err := getCluster(&Redis{
|
||||
Addr: r.Addr(),
|
||||
Type: ClusterType,
|
||||
protocol: 2,
|
||||
identity: true,
|
||||
hooks: []red.Hook{defaultDurationHook},
|
||||
})
|
||||
if assert.NoError(t, err) {
|
||||
assert.NotNil(t, c)
|
||||
assert.Equal(t, 2, c.Options().Protocol)
|
||||
assert.True(t, c.Options().DisableIdentity)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClusterWithMaintNotifications(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode maintnotifications.Mode
|
||||
want maintnotifications.Mode
|
||||
}{
|
||||
{name: "unset falls back to disabled", mode: "", want: maintnotifications.ModeDisabled},
|
||||
{name: "disabled", mode: maintnotifications.ModeDisabled, want: maintnotifications.ModeDisabled},
|
||||
{name: "auto", mode: maintnotifications.ModeAuto, want: maintnotifications.ModeAuto},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
r := miniredis.RunT(t)
|
||||
defer r.Close()
|
||||
c, err := getCluster(&Redis{
|
||||
Addr: r.Addr(),
|
||||
Type: ClusterType,
|
||||
maintNotifications: test.mode,
|
||||
hooks: []red.Hook{defaultDurationHook},
|
||||
})
|
||||
if assert.NoError(t, err) {
|
||||
assert.NotNil(t, c)
|
||||
assert.NotNil(t, c.Options().MaintNotificationsConfig)
|
||||
assert.Equal(t, test.want, c.Options().MaintNotificationsConfig.Mode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,12 +67,9 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
|
||||
serviceName := stringx.From(service.Name).ToCamel()
|
||||
|
||||
// Collect only the message types actually used by this service's RPCs,
|
||||
// so that each client file only aliases its own request/response types.
|
||||
usedTypes := collection.NewSet[string]()
|
||||
for _, rpc := range service.RPC {
|
||||
usedTypes.Add(parser.CamelCase(rpc.RequestType))
|
||||
usedTypes.Add(parser.CamelCase(rpc.ReturnsType))
|
||||
}
|
||||
// so that each client file only aliases its own request/response types
|
||||
// and their same-file message dependencies.
|
||||
usedTypes := collectServiceUsedTypes(proto.Message, service)
|
||||
|
||||
alias := collection.NewSet[string]()
|
||||
var hasSameNameBetweenMessageAndService bool
|
||||
@@ -337,17 +334,85 @@ func (g *Generator) getInterfaceFuncs(goPackage, mainGoPackage string, service p
|
||||
return functions, nil
|
||||
}
|
||||
|
||||
// collectServiceUsedTypes returns the set of CamelCase message names that are
|
||||
// reachable from any of the service's RPC request or response types via field
|
||||
// references within the same proto file. This ensures per-service client files
|
||||
// alias their own request/response types and all transitively-referenced message
|
||||
// types, but never unrelated messages from other services.
|
||||
func collectServiceUsedTypes(messages []parser.Message, service parser.Service) *collection.Set[string] {
|
||||
messageByName := make(map[string]*proto.Message, len(messages))
|
||||
for _, item := range messages {
|
||||
msgName := parser.CamelCase(getMessageName(*item.Message))
|
||||
messageByName[msgName] = item.Message
|
||||
}
|
||||
|
||||
usedTypes := collection.NewSet[string]()
|
||||
for _, rpc := range service.RPC {
|
||||
collectMessageDependencies(rpc.RequestType, messageByName, usedTypes)
|
||||
collectMessageDependencies(rpc.ReturnsType, messageByName, usedTypes)
|
||||
}
|
||||
|
||||
return usedTypes
|
||||
}
|
||||
|
||||
// collectMessageDependencies recursively adds protoType and all message types
|
||||
// referenced by its fields into usedTypes, looking up messages by CamelCase
|
||||
// name in messageByName. The cycle guard (usedTypes.Contains) prevents
|
||||
// infinite recursion on circular field references.
|
||||
func collectMessageDependencies(protoType string, messageByName map[string]*proto.Message,
|
||||
usedTypes *collection.Set[string]) {
|
||||
for _, candidate := range messageTypeCandidates(protoType) {
|
||||
msg, ok := messageByName[candidate]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if usedTypes.Contains(candidate) {
|
||||
return
|
||||
}
|
||||
|
||||
usedTypes.Add(candidate)
|
||||
for _, elem := range msg.Elements {
|
||||
switch field := elem.(type) {
|
||||
case *proto.NormalField:
|
||||
collectMessageDependencies(field.Type, messageByName, usedTypes)
|
||||
case *proto.MapField:
|
||||
// Map key types are always scalars in proto3; only the value type
|
||||
// can be a message.
|
||||
collectMessageDependencies(field.Type, messageByName, usedTypes)
|
||||
case *proto.Oneof:
|
||||
for _, oneofElem := range field.Elements {
|
||||
if oneofField, ok := oneofElem.(*proto.OneOfField); ok {
|
||||
collectMessageDependencies(oneofField.Type, messageByName, usedTypes)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// messageTypeCandidates returns the CamelCase lookup keys to try for a proto
|
||||
// field type. Two candidates are produced to handle both simple names
|
||||
// ("MyMsg") and dotted/qualified names ("pkg.MyMsg" → "PkgMyMsg").
|
||||
func messageTypeCandidates(protoType string) []string {
|
||||
protoType = strings.TrimPrefix(protoType, ".")
|
||||
return []string{
|
||||
parser.CamelCase(protoType),
|
||||
parser.CamelCase(strings.ReplaceAll(protoType, ".", "_")),
|
||||
}
|
||||
}
|
||||
|
||||
// buildExtraImportLines converts a set of import paths into quoted import lines
|
||||
// for use in the call.tpl {{.extraImports}} placeholder.
|
||||
func buildExtraImportLines(extraImports *collection.Set[string]) string {
|
||||
if extraImports.Count() == 0 {
|
||||
return ""
|
||||
}
|
||||
keys := extraImports.Keys()
|
||||
sort.Strings(keys)
|
||||
lines := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
lines = append(lines, fmt.Sprintf(`"%s"`, k))
|
||||
}
|
||||
return strings.Join(lines, "\n\t")
|
||||
if extraImports.Count() == 0 {
|
||||
return ""
|
||||
}
|
||||
keys := extraImports.Keys()
|
||||
sort.Strings(keys)
|
||||
lines := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
lines = append(lines, fmt.Sprintf(`"%s"`, k))
|
||||
}
|
||||
return strings.Join(lines, "\n\t")
|
||||
}
|
||||
|
||||
@@ -34,50 +34,261 @@ func (m *mockDirContext) GetMain() Dir { return Dir{} }
|
||||
func (m *mockDirContext) GetServiceName() stringx.String { return stringx.From("test") }
|
||||
func (m *mockDirContext) SetPbDir(pbDir, grpcDir string) {}
|
||||
|
||||
// TestGenCallGroup_OnlyUsedTypesAliased verifies that in multi-service mode each
|
||||
// generated client file contains type aliases only for the message types actually
|
||||
// used by that service's RPCs (fix for issue #5481).
|
||||
// newTestDirContext builds a mockDirContext that writes generated files under
|
||||
// callBase, with a pb directory that differs (so alias generation is triggered).
|
||||
func newTestDirContext(t *testing.T, callBase, pbBase string, services ...string) *mockDirContext {
|
||||
t.Helper()
|
||||
for _, svc := range services {
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(callBase, strings.ToLower(svc)), 0755))
|
||||
}
|
||||
require.NoError(t, os.MkdirAll(pbBase, 0755))
|
||||
return &mockDirContext{
|
||||
callDir: Dir{
|
||||
Filename: callBase,
|
||||
Package: "example.com/test/call",
|
||||
Base: "call",
|
||||
GetChildPackage: func(childPath string) (string, error) {
|
||||
return filepath.Join(callBase, strings.ToLower(childPath)), nil
|
||||
},
|
||||
},
|
||||
pbDir: Dir{Filename: pbBase, Package: "example.com/test/pb", Base: "pb"},
|
||||
protoGo: Dir{
|
||||
// Must differ from service dir names so isCallPkgSameToPbPkg stays
|
||||
// false and alias generation is triggered.
|
||||
Filename: pbBase,
|
||||
Package: "example.com/test/pb",
|
||||
Base: "pb",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ---- unit tests for collectServiceUsedTypes --------------------------------
|
||||
|
||||
// TestCollectServiceUsedTypes_DirectOnly verifies that request and response
|
||||
// types with no message fields are collected as-is.
|
||||
func TestCollectServiceUsedTypes_DirectOnly(t *testing.T) {
|
||||
messages := []parser.Message{
|
||||
{Message: &proto.Message{Name: "AReq"}},
|
||||
{Message: &proto.Message{Name: "AResp"}},
|
||||
{Message: &proto.Message{Name: "Unrelated"}},
|
||||
}
|
||||
service := parser.Service{
|
||||
Service: &proto.Service{Name: "ServiceA"},
|
||||
RPC: []*parser.RPC{
|
||||
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
|
||||
},
|
||||
}
|
||||
|
||||
got := collectServiceUsedTypes(messages, service)
|
||||
|
||||
assert.True(t, got.Contains("AReq"))
|
||||
assert.True(t, got.Contains("AResp"))
|
||||
assert.False(t, got.Contains("Unrelated"), "unrelated message must not be collected")
|
||||
}
|
||||
|
||||
// TestCollectServiceUsedTypes_NestedNormalField verifies that a message type
|
||||
// referenced via a NormalField inside a response is transitively collected
|
||||
// (regression test for issue #5618).
|
||||
func TestCollectServiceUsedTypes_NestedNormalField(t *testing.T) {
|
||||
messages := []parser.Message{
|
||||
{Message: &proto.Message{Name: "AReq"}},
|
||||
{Message: &proto.Message{
|
||||
Name: "AResp",
|
||||
Elements: []proto.Visitee{
|
||||
&proto.NormalField{Field: &proto.Field{Name: "items", Type: "AItem"}},
|
||||
},
|
||||
}},
|
||||
{Message: &proto.Message{Name: "AItem"}},
|
||||
}
|
||||
service := parser.Service{
|
||||
Service: &proto.Service{Name: "ServiceA"},
|
||||
RPC: []*parser.RPC{
|
||||
{RPC: &proto.RPC{Name: "List", RequestType: "AReq", ReturnsType: "AResp"}},
|
||||
},
|
||||
}
|
||||
|
||||
got := collectServiceUsedTypes(messages, service)
|
||||
|
||||
assert.True(t, got.Contains("AReq"))
|
||||
assert.True(t, got.Contains("AResp"))
|
||||
assert.True(t, got.Contains("AItem"), "field type AItem must be transitively collected")
|
||||
}
|
||||
|
||||
// TestCollectServiceUsedTypes_MapValueField verifies that the value type of a
|
||||
// MapField inside a response message is transitively collected.
|
||||
func TestCollectServiceUsedTypes_MapValueField(t *testing.T) {
|
||||
messages := []parser.Message{
|
||||
{Message: &proto.Message{Name: "AReq"}},
|
||||
{Message: &proto.Message{
|
||||
Name: "AResp",
|
||||
Elements: []proto.Visitee{
|
||||
&proto.MapField{KeyType: "string", Field: &proto.Field{Name: "index", Type: "AItem"}},
|
||||
},
|
||||
}},
|
||||
{Message: &proto.Message{Name: "AItem"}},
|
||||
}
|
||||
service := parser.Service{
|
||||
Service: &proto.Service{Name: "ServiceA"},
|
||||
RPC: []*parser.RPC{
|
||||
{RPC: &proto.RPC{Name: "GetMap", RequestType: "AReq", ReturnsType: "AResp"}},
|
||||
},
|
||||
}
|
||||
|
||||
got := collectServiceUsedTypes(messages, service)
|
||||
|
||||
assert.True(t, got.Contains("AResp"))
|
||||
assert.True(t, got.Contains("AItem"), "map value type AItem must be transitively collected")
|
||||
}
|
||||
|
||||
// TestCollectServiceUsedTypes_OneofField verifies that message types referenced
|
||||
// inside a Oneof element are transitively collected.
|
||||
func TestCollectServiceUsedTypes_OneofField(t *testing.T) {
|
||||
oneof := &proto.Oneof{Name: "result"}
|
||||
oneof.Elements = []proto.Visitee{
|
||||
&proto.OneOfField{Field: &proto.Field{Name: "success", Type: "SuccessMsg"}},
|
||||
&proto.OneOfField{Field: &proto.Field{Name: "failure", Type: "FailureMsg"}},
|
||||
}
|
||||
messages := []parser.Message{
|
||||
{Message: &proto.Message{Name: "AReq"}},
|
||||
{Message: &proto.Message{
|
||||
Name: "AResp",
|
||||
Elements: []proto.Visitee{oneof},
|
||||
}},
|
||||
{Message: &proto.Message{Name: "SuccessMsg"}},
|
||||
{Message: &proto.Message{Name: "FailureMsg"}},
|
||||
}
|
||||
service := parser.Service{
|
||||
Service: &proto.Service{Name: "ServiceA"},
|
||||
RPC: []*parser.RPC{
|
||||
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
|
||||
},
|
||||
}
|
||||
|
||||
got := collectServiceUsedTypes(messages, service)
|
||||
|
||||
assert.True(t, got.Contains("AResp"))
|
||||
assert.True(t, got.Contains("SuccessMsg"), "oneof field type SuccessMsg must be collected")
|
||||
assert.True(t, got.Contains("FailureMsg"), "oneof field type FailureMsg must be collected")
|
||||
}
|
||||
|
||||
// TestCollectServiceUsedTypes_MultiLevelTransitive verifies that a chain
|
||||
// AResp → BMsg → CMsg is fully collected (multi-level transitivity).
|
||||
func TestCollectServiceUsedTypes_MultiLevelTransitive(t *testing.T) {
|
||||
messages := []parser.Message{
|
||||
{Message: &proto.Message{Name: "AReq"}},
|
||||
{Message: &proto.Message{
|
||||
Name: "AResp",
|
||||
Elements: []proto.Visitee{
|
||||
&proto.NormalField{Field: &proto.Field{Name: "b", Type: "BMsg"}},
|
||||
},
|
||||
}},
|
||||
{Message: &proto.Message{
|
||||
Name: "BMsg",
|
||||
Elements: []proto.Visitee{
|
||||
&proto.NormalField{Field: &proto.Field{Name: "c", Type: "CMsg"}},
|
||||
},
|
||||
}},
|
||||
{Message: &proto.Message{Name: "CMsg"}},
|
||||
}
|
||||
service := parser.Service{
|
||||
Service: &proto.Service{Name: "ServiceA"},
|
||||
RPC: []*parser.RPC{
|
||||
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
|
||||
},
|
||||
}
|
||||
|
||||
got := collectServiceUsedTypes(messages, service)
|
||||
|
||||
assert.True(t, got.Contains("AReq"))
|
||||
assert.True(t, got.Contains("AResp"))
|
||||
assert.True(t, got.Contains("BMsg"), "BMsg must be transitively collected via AResp")
|
||||
assert.True(t, got.Contains("CMsg"), "CMsg must be transitively collected via BMsg")
|
||||
}
|
||||
|
||||
// TestCollectServiceUsedTypes_CycleDetection verifies that circular field
|
||||
// references (AResp ↔ BMsg) do not cause infinite recursion.
|
||||
func TestCollectServiceUsedTypes_CycleDetection(t *testing.T) {
|
||||
messages := []parser.Message{
|
||||
{Message: &proto.Message{Name: "AReq"}},
|
||||
{Message: &proto.Message{
|
||||
Name: "AResp",
|
||||
Elements: []proto.Visitee{
|
||||
&proto.NormalField{Field: &proto.Field{Name: "b", Type: "BMsg"}},
|
||||
},
|
||||
}},
|
||||
{Message: &proto.Message{
|
||||
Name: "BMsg",
|
||||
Elements: []proto.Visitee{
|
||||
// circular back-reference to AResp
|
||||
&proto.NormalField{Field: &proto.Field{Name: "a", Type: "AResp"}},
|
||||
},
|
||||
}},
|
||||
}
|
||||
service := parser.Service{
|
||||
Service: &proto.Service{Name: "ServiceA"},
|
||||
RPC: []*parser.RPC{
|
||||
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
|
||||
},
|
||||
}
|
||||
|
||||
// Must not panic or loop; both messages are reachable.
|
||||
got := collectServiceUsedTypes(messages, service)
|
||||
|
||||
assert.True(t, got.Contains("AResp"))
|
||||
assert.True(t, got.Contains("BMsg"))
|
||||
}
|
||||
|
||||
// TestCollectServiceUsedTypes_ExcludesUnrelatedService verifies that messages
|
||||
// belonging only to another service are not included.
|
||||
func TestCollectServiceUsedTypes_ExcludesUnrelatedService(t *testing.T) {
|
||||
messages := []parser.Message{
|
||||
{Message: &proto.Message{Name: "AReq"}},
|
||||
{Message: &proto.Message{Name: "AResp"}},
|
||||
{Message: &proto.Message{Name: "BReq"}},
|
||||
{Message: &proto.Message{Name: "BResp"}},
|
||||
}
|
||||
service := parser.Service{
|
||||
Service: &proto.Service{Name: "ServiceA"},
|
||||
RPC: []*parser.RPC{
|
||||
{RPC: &proto.RPC{Name: "DoA", RequestType: "AReq", ReturnsType: "AResp"}},
|
||||
},
|
||||
}
|
||||
|
||||
got := collectServiceUsedTypes(messages, service)
|
||||
|
||||
assert.True(t, got.Contains("AReq"))
|
||||
assert.True(t, got.Contains("AResp"))
|
||||
assert.False(t, got.Contains("BReq"), "BReq belongs to ServiceB and must be excluded")
|
||||
assert.False(t, got.Contains("BResp"), "BResp belongs to ServiceB and must be excluded")
|
||||
}
|
||||
|
||||
// ---- integration tests via genCallGroup ------------------------------------
|
||||
|
||||
// TestGenCallGroup_OnlyUsedTypesAliased verifies that in multi-service mode
|
||||
// each generated client file aliases only its own request/response types and
|
||||
// their transitive field dependencies (fix for issues #5481 and #5618).
|
||||
func TestGenCallGroup_OnlyUsedTypesAliased(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
callBase := filepath.Join(tmpDir, "call")
|
||||
pbBase := filepath.Join(tmpDir, "pb")
|
||||
|
||||
// Pre-create subdirs that genCallGroup will write into.
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(callBase, "servicea"), 0755))
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(callBase, "serviceb"), 0755))
|
||||
require.NoError(t, os.MkdirAll(pbBase, 0755))
|
||||
mctx := newTestDirContext(t, callBase, pbBase, "ServiceA", "ServiceB")
|
||||
|
||||
mctx := &mockDirContext{
|
||||
callDir: Dir{
|
||||
Filename: callBase,
|
||||
Package: "example.com/multitest/call",
|
||||
Base: "call",
|
||||
GetChildPackage: func(childPath string) (string, error) {
|
||||
// Return a package path whose Base() is the lowercase service name.
|
||||
return filepath.Join(callBase, strings.ToLower(childPath)), nil
|
||||
},
|
||||
},
|
||||
pbDir: Dir{
|
||||
Filename: pbBase,
|
||||
Package: "example.com/multitest/pb",
|
||||
Base: "pb",
|
||||
},
|
||||
protoGo: Dir{
|
||||
// Must differ from "servicea"/"serviceb" so isCallPkgSameToPbPkg stays false
|
||||
// and alias generation is triggered.
|
||||
Filename: pbBase,
|
||||
Package: "example.com/multitest/pb",
|
||||
Base: "pb",
|
||||
},
|
||||
}
|
||||
|
||||
// Proto with two services that use completely disjoint message types.
|
||||
// ServiceA: AResp contains a NormalField of type AItem (issue #5618).
|
||||
// ServiceB: BResp has no nested message fields.
|
||||
// AItem must appear in ServiceA's file but not ServiceB's.
|
||||
protoData := parser.Proto{
|
||||
Name: "multi.proto",
|
||||
PbPackage: "pb",
|
||||
Message: []parser.Message{
|
||||
{Message: &proto.Message{Name: "AReq"}},
|
||||
{Message: &proto.Message{Name: "AResp"}},
|
||||
{Message: &proto.Message{
|
||||
Name: "AResp",
|
||||
Elements: []proto.Visitee{
|
||||
&proto.NormalField{Field: &proto.Field{Name: "items", Type: "AItem"}},
|
||||
},
|
||||
}},
|
||||
{Message: &proto.Message{Name: "AItem"}},
|
||||
{Message: &proto.Message{Name: "BReq"}},
|
||||
{Message: &proto.Message{Name: "BResp"}},
|
||||
},
|
||||
@@ -99,29 +310,163 @@ func TestGenCallGroup_OnlyUsedTypesAliased(t *testing.T) {
|
||||
|
||||
cfg, err := conf.NewConfig("")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, NewGenerator("gozero", false).genCallGroup(mctx, protoData, cfg))
|
||||
|
||||
g := NewGenerator("gozero", false)
|
||||
require.NoError(t, g.genCallGroup(mctx, protoData, cfg))
|
||||
aFile := normalizeWS(readGenFile(t, callBase, "servicea", "servicea.go"))
|
||||
assert.Contains(t, aFile, "AReq = pb.AReq", "ServiceA must alias AReq")
|
||||
assert.Contains(t, aFile, "AResp = pb.AResp", "ServiceA must alias AResp")
|
||||
assert.Contains(t, aFile, "AItem = pb.AItem", "ServiceA must alias AItem (transitive NormalField)")
|
||||
assert.NotContains(t, aFile, "BReq = pb.BReq", "ServiceA must not alias BReq")
|
||||
assert.NotContains(t, aFile, "BResp = pb.BResp", "ServiceA must not alias BResp")
|
||||
|
||||
// servicea/servicea.go — aliases for AReq/AResp only
|
||||
aContent, err := os.ReadFile(filepath.Join(callBase, "servicea", "servicea.go"))
|
||||
bFile := normalizeWS(readGenFile(t, callBase, "serviceb", "serviceb.go"))
|
||||
assert.Contains(t, bFile, "BReq = pb.BReq", "ServiceB must alias BReq")
|
||||
assert.Contains(t, bFile, "BResp = pb.BResp", "ServiceB must alias BResp")
|
||||
assert.NotContains(t, bFile, "AReq = pb.AReq", "ServiceB must not alias AReq")
|
||||
assert.NotContains(t, bFile, "AResp = pb.AResp", "ServiceB must not alias AResp")
|
||||
assert.NotContains(t, bFile, "AItem = pb.AItem", "ServiceB must not alias AItem")
|
||||
}
|
||||
|
||||
// TestGenCallGroup_MapValueAliased verifies that the value type of a MapField
|
||||
// inside a service response is included in the generated aliases.
|
||||
func TestGenCallGroup_MapValueAliased(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
callBase := filepath.Join(tmpDir, "call")
|
||||
pbBase := filepath.Join(tmpDir, "pb")
|
||||
|
||||
mctx := newTestDirContext(t, callBase, pbBase, "ServiceA")
|
||||
|
||||
protoData := parser.Proto{
|
||||
Name: "map.proto",
|
||||
PbPackage: "pb",
|
||||
Message: []parser.Message{
|
||||
{Message: &proto.Message{Name: "AReq"}},
|
||||
{Message: &proto.Message{
|
||||
Name: "AResp",
|
||||
Elements: []proto.Visitee{
|
||||
&proto.MapField{KeyType: "string", Field: &proto.Field{Name: "index", Type: "AItem"}},
|
||||
},
|
||||
}},
|
||||
{Message: &proto.Message{Name: "AItem"}},
|
||||
},
|
||||
Service: parser.Services{
|
||||
{
|
||||
Service: &proto.Service{Name: "ServiceA"},
|
||||
RPC: []*parser.RPC{
|
||||
{RPC: &proto.RPC{Name: "GetMap", RequestType: "AReq", ReturnsType: "AResp"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg, err := conf.NewConfig("")
|
||||
require.NoError(t, err)
|
||||
aFile := normalizeWS(string(aContent))
|
||||
require.NoError(t, NewGenerator("gozero", false).genCallGroup(mctx, protoData, cfg))
|
||||
|
||||
assert.Contains(t, aFile, "AReq = pb.AReq", "ServiceA file should alias AReq")
|
||||
assert.Contains(t, aFile, "AResp = pb.AResp", "ServiceA file should alias AResp")
|
||||
assert.NotContains(t, aFile, "BReq = pb.BReq", "ServiceA file must not alias BReq")
|
||||
assert.NotContains(t, aFile, "BResp = pb.BResp", "ServiceA file must not alias BResp")
|
||||
aFile := normalizeWS(readGenFile(t, callBase, "servicea", "servicea.go"))
|
||||
assert.Contains(t, aFile, "AResp = pb.AResp")
|
||||
assert.Contains(t, aFile, "AItem = pb.AItem", "map value type AItem must be aliased")
|
||||
}
|
||||
|
||||
// serviceb/serviceb.go — aliases for BReq/BResp only
|
||||
bContent, err := os.ReadFile(filepath.Join(callBase, "serviceb", "serviceb.go"))
|
||||
// TestGenCallGroup_OneofAliased verifies that message types referenced inside a
|
||||
// Oneof element are included in the generated aliases.
|
||||
func TestGenCallGroup_OneofAliased(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
callBase := filepath.Join(tmpDir, "call")
|
||||
pbBase := filepath.Join(tmpDir, "pb")
|
||||
|
||||
mctx := newTestDirContext(t, callBase, pbBase, "ServiceA")
|
||||
|
||||
oneof := &proto.Oneof{Name: "result"}
|
||||
oneof.Elements = []proto.Visitee{
|
||||
&proto.OneOfField{Field: &proto.Field{Name: "ok", Type: "SuccessMsg"}},
|
||||
&proto.OneOfField{Field: &proto.Field{Name: "err", Type: "FailureMsg"}},
|
||||
}
|
||||
protoData := parser.Proto{
|
||||
Name: "oneof.proto",
|
||||
PbPackage: "pb",
|
||||
Message: []parser.Message{
|
||||
{Message: &proto.Message{Name: "AReq"}},
|
||||
{Message: &proto.Message{
|
||||
Name: "AResp",
|
||||
Elements: []proto.Visitee{oneof},
|
||||
}},
|
||||
{Message: &proto.Message{Name: "SuccessMsg"}},
|
||||
{Message: &proto.Message{Name: "FailureMsg"}},
|
||||
},
|
||||
Service: parser.Services{
|
||||
{
|
||||
Service: &proto.Service{Name: "ServiceA"},
|
||||
RPC: []*parser.RPC{
|
||||
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg, err := conf.NewConfig("")
|
||||
require.NoError(t, err)
|
||||
bFile := normalizeWS(string(bContent))
|
||||
require.NoError(t, NewGenerator("gozero", false).genCallGroup(mctx, protoData, cfg))
|
||||
|
||||
assert.Contains(t, bFile, "BReq = pb.BReq", "ServiceB file should alias BReq")
|
||||
assert.Contains(t, bFile, "BResp = pb.BResp", "ServiceB file should alias BResp")
|
||||
assert.NotContains(t, bFile, "AReq = pb.AReq", "ServiceB file must not alias AReq")
|
||||
assert.NotContains(t, bFile, "AResp = pb.AResp", "ServiceB file must not alias AResp")
|
||||
aFile := normalizeWS(readGenFile(t, callBase, "servicea", "servicea.go"))
|
||||
assert.Contains(t, aFile, "SuccessMsg = pb.SuccessMsg", "oneof type SuccessMsg must be aliased")
|
||||
assert.Contains(t, aFile, "FailureMsg = pb.FailureMsg", "oneof type FailureMsg must be aliased")
|
||||
}
|
||||
|
||||
// TestGenCallGroup_MultiLevelTransitiveAliased verifies that a dependency chain
|
||||
// AResp → BMsg → CMsg causes all three types to be aliased in the client file.
|
||||
func TestGenCallGroup_MultiLevelTransitiveAliased(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
callBase := filepath.Join(tmpDir, "call")
|
||||
pbBase := filepath.Join(tmpDir, "pb")
|
||||
|
||||
mctx := newTestDirContext(t, callBase, pbBase, "ServiceA")
|
||||
|
||||
protoData := parser.Proto{
|
||||
Name: "transitive.proto",
|
||||
PbPackage: "pb",
|
||||
Message: []parser.Message{
|
||||
{Message: &proto.Message{Name: "AReq"}},
|
||||
{Message: &proto.Message{
|
||||
Name: "AResp",
|
||||
Elements: []proto.Visitee{
|
||||
&proto.NormalField{Field: &proto.Field{Name: "b", Type: "BMsg"}},
|
||||
},
|
||||
}},
|
||||
{Message: &proto.Message{
|
||||
Name: "BMsg",
|
||||
Elements: []proto.Visitee{
|
||||
&proto.NormalField{Field: &proto.Field{Name: "c", Type: "CMsg"}},
|
||||
},
|
||||
}},
|
||||
{Message: &proto.Message{Name: "CMsg"}},
|
||||
},
|
||||
Service: parser.Services{
|
||||
{
|
||||
Service: &proto.Service{Name: "ServiceA"},
|
||||
RPC: []*parser.RPC{
|
||||
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg, err := conf.NewConfig("")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, NewGenerator("gozero", false).genCallGroup(mctx, protoData, cfg))
|
||||
|
||||
aFile := normalizeWS(readGenFile(t, callBase, "servicea", "servicea.go"))
|
||||
assert.Contains(t, aFile, "AResp = pb.AResp")
|
||||
assert.Contains(t, aFile, "BMsg = pb.BMsg", "BMsg must be transitively aliased via AResp")
|
||||
assert.Contains(t, aFile, "CMsg = pb.CMsg", "CMsg must be transitively aliased via BMsg")
|
||||
}
|
||||
|
||||
// readGenFile reads a generated file relative to callBase and returns its content.
|
||||
func readGenFile(t *testing.T, callBase string, parts ...string) string {
|
||||
t.Helper()
|
||||
content, err := os.ReadFile(filepath.Join(append([]string{callBase}, parts...)...))
|
||||
require.NoError(t, err)
|
||||
return string(content)
|
||||
}
|
||||
|
||||
// normalizeWS replaces runs of whitespace with a single space.
|
||||
|
||||
Reference in New Issue
Block a user