Compare commits

...

2 Commits

Author SHA1 Message Date
Puneet Dixit
f910257ec9 fix(goctl): include nested client aliases (#5627)
Co-authored-by: Deepak kudi <deepakkudi23@adsl-172-10-9-116.dsl.sndg02.sbcglobal.net>
Co-authored-by: kevin <wanjunfeng@gmail.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-06-27 16:09:03 +00:00
YangWJ
d318de1212 fix(mapping): correct unmarshaling of pointer-to-slice fields (#5662)
Co-authored-by: kevin <wanjunfeng@gmail.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-06-27 14:43:41 +00:00
4 changed files with 587 additions and 70 deletions

View File

@@ -931,6 +931,113 @@ func TestUnmarshalJsonArray(t *testing.T) {
assert.Equal(t, 18, v[0].Age)
}
func TestUnmarshalJsonBytesPointerSliceUint64(t *testing.T) {
t.Run("with values", func(t *testing.T) {
var c struct {
IDs *[]uint64 `json:"ids,optional"`
}
content := []byte(`{"ids":[9000,9001]}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.NotNil(t, c.IDs)
assert.Equal(t, []uint64{9000, 9001}, *c.IDs)
})
t.Run("omitted", func(t *testing.T) {
var c struct {
IDs *[]uint64 `json:"ids,optional"`
}
content := []byte(`{}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.Nil(t, c.IDs)
})
t.Run("null", func(t *testing.T) {
var c struct {
IDs *[]uint64 `json:"ids,optional"`
}
content := []byte(`{"ids":null}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.Nil(t, c.IDs)
})
t.Run("empty array", func(t *testing.T) {
var c struct {
IDs *[]uint64 `json:"ids,optional"`
}
content := []byte(`{"ids":[]}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.NotNil(t, c.IDs)
assert.Equal(t, []uint64{}, *c.IDs)
})
}
func TestUnmarshalJsonBytesPointerSliceOtherTypes(t *testing.T) {
t.Run("pointer to []string", func(t *testing.T) {
var c struct {
Names *[]string `json:"names,optional"`
}
content := []byte(`{"names":["a","b"]}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.NotNil(t, c.Names)
assert.Equal(t, []string{"a", "b"}, *c.Names)
})
t.Run("pointer to []int", func(t *testing.T) {
var c struct {
Values *[]int `json:"values,optional"`
}
content := []byte(`{"values":[1,2,3]}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.NotNil(t, c.Values)
assert.Equal(t, []int{1, 2, 3}, *c.Values)
})
}
func TestUnmarshalJsonBytesPointerSliceStruct(t *testing.T) {
type Item struct {
Name string `json:"name"`
Age int `json:"age"`
}
t.Run("with values", func(t *testing.T) {
var c struct {
Items *[]Item `json:"items,optional"`
}
content := []byte(`{"items":[{"name":"alice","age":30},{"name":"bob","age":25}]}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.NotNil(t, c.Items)
assert.Equal(t, []Item{{Name: "alice", Age: 30}, {Name: "bob", Age: 25}}, *c.Items)
})
t.Run("omitted", func(t *testing.T) {
var c struct {
Items *[]Item `json:"items,optional"`
}
content := []byte(`{}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.Nil(t, c.Items)
})
t.Run("empty array", func(t *testing.T) {
var c struct {
Items *[]Item `json:"items,optional"`
}
content := []byte(`{"items":[]}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.NotNil(t, c.Items)
assert.Equal(t, []Item{}, *c.Items)
})
}
func TestUnmarshalJsonBytesError(t *testing.T) {
var v []struct {
Name string `json:"name"`

View File

@@ -142,11 +142,11 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,
return nil
}
baseType := fieldType.Elem()
baseType := Deref(fieldType).Elem()
dereffedBaseType := Deref(baseType)
dereffedBaseKind := dereffedBaseType.Kind()
if refValue.Len() == 0 {
value.Set(reflect.MakeSlice(reflect.SliceOf(baseType), 0, 0))
SetValue(fieldType, value, reflect.MakeSlice(reflect.SliceOf(baseType), 0, 0))
return nil
}
@@ -179,7 +179,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,
}
if valid {
value.Set(conv)
SetValue(fieldType, value, conv)
}
return nil
@@ -201,7 +201,7 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
return errUnsupportedType
}
baseFieldType := fieldType.Elem()
baseFieldType := Deref(fieldType).Elem()
baseFieldKind := baseFieldType.Kind()
conv := reflect.MakeSlice(reflect.SliceOf(baseFieldType), len(slice), cap(slice))
@@ -211,7 +211,7 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
}
}
value.Set(conv)
SetValue(fieldType, value, conv)
return nil
}

View File

@@ -67,12 +67,9 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
serviceName := stringx.From(service.Name).ToCamel()
// Collect only the message types actually used by this service's RPCs,
// so that each client file only aliases its own request/response types.
usedTypes := collection.NewSet[string]()
for _, rpc := range service.RPC {
usedTypes.Add(parser.CamelCase(rpc.RequestType))
usedTypes.Add(parser.CamelCase(rpc.ReturnsType))
}
// so that each client file only aliases its own request/response types
// and their same-file message dependencies.
usedTypes := collectServiceUsedTypes(proto.Message, service)
alias := collection.NewSet[string]()
var hasSameNameBetweenMessageAndService bool
@@ -337,17 +334,85 @@ func (g *Generator) getInterfaceFuncs(goPackage, mainGoPackage string, service p
return functions, nil
}
// collectServiceUsedTypes returns the set of CamelCase message names that are
// reachable from any of the service's RPC request or response types via field
// references within the same proto file. This ensures per-service client files
// alias their own request/response types and all transitively-referenced message
// types, but never unrelated messages from other services.
func collectServiceUsedTypes(messages []parser.Message, service parser.Service) *collection.Set[string] {
messageByName := make(map[string]*proto.Message, len(messages))
for _, item := range messages {
msgName := parser.CamelCase(getMessageName(*item.Message))
messageByName[msgName] = item.Message
}
usedTypes := collection.NewSet[string]()
for _, rpc := range service.RPC {
collectMessageDependencies(rpc.RequestType, messageByName, usedTypes)
collectMessageDependencies(rpc.ReturnsType, messageByName, usedTypes)
}
return usedTypes
}
// collectMessageDependencies recursively adds protoType and all message types
// referenced by its fields into usedTypes, looking up messages by CamelCase
// name in messageByName. The cycle guard (usedTypes.Contains) prevents
// infinite recursion on circular field references.
func collectMessageDependencies(protoType string, messageByName map[string]*proto.Message,
usedTypes *collection.Set[string]) {
for _, candidate := range messageTypeCandidates(protoType) {
msg, ok := messageByName[candidate]
if !ok {
continue
}
if usedTypes.Contains(candidate) {
return
}
usedTypes.Add(candidate)
for _, elem := range msg.Elements {
switch field := elem.(type) {
case *proto.NormalField:
collectMessageDependencies(field.Type, messageByName, usedTypes)
case *proto.MapField:
// Map key types are always scalars in proto3; only the value type
// can be a message.
collectMessageDependencies(field.Type, messageByName, usedTypes)
case *proto.Oneof:
for _, oneofElem := range field.Elements {
if oneofField, ok := oneofElem.(*proto.OneOfField); ok {
collectMessageDependencies(oneofField.Type, messageByName, usedTypes)
}
}
}
}
return
}
}
// messageTypeCandidates returns the CamelCase lookup keys to try for a proto
// field type. Two candidates are produced to handle both simple names
// ("MyMsg") and dotted/qualified names ("pkg.MyMsg" → "PkgMyMsg").
func messageTypeCandidates(protoType string) []string {
protoType = strings.TrimPrefix(protoType, ".")
return []string{
parser.CamelCase(protoType),
parser.CamelCase(strings.ReplaceAll(protoType, ".", "_")),
}
}
// buildExtraImportLines converts a set of import paths into quoted import lines
// for use in the call.tpl {{.extraImports}} placeholder.
func buildExtraImportLines(extraImports *collection.Set[string]) string {
if extraImports.Count() == 0 {
return ""
}
keys := extraImports.Keys()
sort.Strings(keys)
lines := make([]string, 0, len(keys))
for _, k := range keys {
lines = append(lines, fmt.Sprintf(`"%s"`, k))
}
return strings.Join(lines, "\n\t")
if extraImports.Count() == 0 {
return ""
}
keys := extraImports.Keys()
sort.Strings(keys)
lines := make([]string, 0, len(keys))
for _, k := range keys {
lines = append(lines, fmt.Sprintf(`"%s"`, k))
}
return strings.Join(lines, "\n\t")
}

View File

@@ -34,50 +34,261 @@ func (m *mockDirContext) GetMain() Dir { return Dir{} }
func (m *mockDirContext) GetServiceName() stringx.String { return stringx.From("test") }
func (m *mockDirContext) SetPbDir(pbDir, grpcDir string) {}
// TestGenCallGroup_OnlyUsedTypesAliased verifies that in multi-service mode each
// generated client file contains type aliases only for the message types actually
// used by that service's RPCs (fix for issue #5481).
// newTestDirContext builds a mockDirContext that writes generated files under
// callBase, with a pb directory that differs (so alias generation is triggered).
func newTestDirContext(t *testing.T, callBase, pbBase string, services ...string) *mockDirContext {
t.Helper()
for _, svc := range services {
require.NoError(t, os.MkdirAll(filepath.Join(callBase, strings.ToLower(svc)), 0755))
}
require.NoError(t, os.MkdirAll(pbBase, 0755))
return &mockDirContext{
callDir: Dir{
Filename: callBase,
Package: "example.com/test/call",
Base: "call",
GetChildPackage: func(childPath string) (string, error) {
return filepath.Join(callBase, strings.ToLower(childPath)), nil
},
},
pbDir: Dir{Filename: pbBase, Package: "example.com/test/pb", Base: "pb"},
protoGo: Dir{
// Must differ from service dir names so isCallPkgSameToPbPkg stays
// false and alias generation is triggered.
Filename: pbBase,
Package: "example.com/test/pb",
Base: "pb",
},
}
}
// ---- unit tests for collectServiceUsedTypes --------------------------------
// TestCollectServiceUsedTypes_DirectOnly verifies that request and response
// types with no message fields are collected as-is.
func TestCollectServiceUsedTypes_DirectOnly(t *testing.T) {
messages := []parser.Message{
{Message: &proto.Message{Name: "AReq"}},
{Message: &proto.Message{Name: "AResp"}},
{Message: &proto.Message{Name: "Unrelated"}},
}
service := parser.Service{
Service: &proto.Service{Name: "ServiceA"},
RPC: []*parser.RPC{
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
},
}
got := collectServiceUsedTypes(messages, service)
assert.True(t, got.Contains("AReq"))
assert.True(t, got.Contains("AResp"))
assert.False(t, got.Contains("Unrelated"), "unrelated message must not be collected")
}
// TestCollectServiceUsedTypes_NestedNormalField verifies that a message type
// referenced via a NormalField inside a response is transitively collected
// (regression test for issue #5618).
func TestCollectServiceUsedTypes_NestedNormalField(t *testing.T) {
messages := []parser.Message{
{Message: &proto.Message{Name: "AReq"}},
{Message: &proto.Message{
Name: "AResp",
Elements: []proto.Visitee{
&proto.NormalField{Field: &proto.Field{Name: "items", Type: "AItem"}},
},
}},
{Message: &proto.Message{Name: "AItem"}},
}
service := parser.Service{
Service: &proto.Service{Name: "ServiceA"},
RPC: []*parser.RPC{
{RPC: &proto.RPC{Name: "List", RequestType: "AReq", ReturnsType: "AResp"}},
},
}
got := collectServiceUsedTypes(messages, service)
assert.True(t, got.Contains("AReq"))
assert.True(t, got.Contains("AResp"))
assert.True(t, got.Contains("AItem"), "field type AItem must be transitively collected")
}
// TestCollectServiceUsedTypes_MapValueField verifies that the value type of a
// MapField inside a response message is transitively collected.
func TestCollectServiceUsedTypes_MapValueField(t *testing.T) {
messages := []parser.Message{
{Message: &proto.Message{Name: "AReq"}},
{Message: &proto.Message{
Name: "AResp",
Elements: []proto.Visitee{
&proto.MapField{KeyType: "string", Field: &proto.Field{Name: "index", Type: "AItem"}},
},
}},
{Message: &proto.Message{Name: "AItem"}},
}
service := parser.Service{
Service: &proto.Service{Name: "ServiceA"},
RPC: []*parser.RPC{
{RPC: &proto.RPC{Name: "GetMap", RequestType: "AReq", ReturnsType: "AResp"}},
},
}
got := collectServiceUsedTypes(messages, service)
assert.True(t, got.Contains("AResp"))
assert.True(t, got.Contains("AItem"), "map value type AItem must be transitively collected")
}
// TestCollectServiceUsedTypes_OneofField verifies that message types referenced
// inside a Oneof element are transitively collected.
func TestCollectServiceUsedTypes_OneofField(t *testing.T) {
oneof := &proto.Oneof{Name: "result"}
oneof.Elements = []proto.Visitee{
&proto.OneOfField{Field: &proto.Field{Name: "success", Type: "SuccessMsg"}},
&proto.OneOfField{Field: &proto.Field{Name: "failure", Type: "FailureMsg"}},
}
messages := []parser.Message{
{Message: &proto.Message{Name: "AReq"}},
{Message: &proto.Message{
Name: "AResp",
Elements: []proto.Visitee{oneof},
}},
{Message: &proto.Message{Name: "SuccessMsg"}},
{Message: &proto.Message{Name: "FailureMsg"}},
}
service := parser.Service{
Service: &proto.Service{Name: "ServiceA"},
RPC: []*parser.RPC{
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
},
}
got := collectServiceUsedTypes(messages, service)
assert.True(t, got.Contains("AResp"))
assert.True(t, got.Contains("SuccessMsg"), "oneof field type SuccessMsg must be collected")
assert.True(t, got.Contains("FailureMsg"), "oneof field type FailureMsg must be collected")
}
// TestCollectServiceUsedTypes_MultiLevelTransitive verifies that a chain
// AResp → BMsg → CMsg is fully collected (multi-level transitivity).
func TestCollectServiceUsedTypes_MultiLevelTransitive(t *testing.T) {
messages := []parser.Message{
{Message: &proto.Message{Name: "AReq"}},
{Message: &proto.Message{
Name: "AResp",
Elements: []proto.Visitee{
&proto.NormalField{Field: &proto.Field{Name: "b", Type: "BMsg"}},
},
}},
{Message: &proto.Message{
Name: "BMsg",
Elements: []proto.Visitee{
&proto.NormalField{Field: &proto.Field{Name: "c", Type: "CMsg"}},
},
}},
{Message: &proto.Message{Name: "CMsg"}},
}
service := parser.Service{
Service: &proto.Service{Name: "ServiceA"},
RPC: []*parser.RPC{
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
},
}
got := collectServiceUsedTypes(messages, service)
assert.True(t, got.Contains("AReq"))
assert.True(t, got.Contains("AResp"))
assert.True(t, got.Contains("BMsg"), "BMsg must be transitively collected via AResp")
assert.True(t, got.Contains("CMsg"), "CMsg must be transitively collected via BMsg")
}
// TestCollectServiceUsedTypes_CycleDetection verifies that circular field
// references (AResp ↔ BMsg) do not cause infinite recursion.
func TestCollectServiceUsedTypes_CycleDetection(t *testing.T) {
messages := []parser.Message{
{Message: &proto.Message{Name: "AReq"}},
{Message: &proto.Message{
Name: "AResp",
Elements: []proto.Visitee{
&proto.NormalField{Field: &proto.Field{Name: "b", Type: "BMsg"}},
},
}},
{Message: &proto.Message{
Name: "BMsg",
Elements: []proto.Visitee{
// circular back-reference to AResp
&proto.NormalField{Field: &proto.Field{Name: "a", Type: "AResp"}},
},
}},
}
service := parser.Service{
Service: &proto.Service{Name: "ServiceA"},
RPC: []*parser.RPC{
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
},
}
// Must not panic or loop; both messages are reachable.
got := collectServiceUsedTypes(messages, service)
assert.True(t, got.Contains("AResp"))
assert.True(t, got.Contains("BMsg"))
}
// TestCollectServiceUsedTypes_ExcludesUnrelatedService verifies that messages
// belonging only to another service are not included.
func TestCollectServiceUsedTypes_ExcludesUnrelatedService(t *testing.T) {
messages := []parser.Message{
{Message: &proto.Message{Name: "AReq"}},
{Message: &proto.Message{Name: "AResp"}},
{Message: &proto.Message{Name: "BReq"}},
{Message: &proto.Message{Name: "BResp"}},
}
service := parser.Service{
Service: &proto.Service{Name: "ServiceA"},
RPC: []*parser.RPC{
{RPC: &proto.RPC{Name: "DoA", RequestType: "AReq", ReturnsType: "AResp"}},
},
}
got := collectServiceUsedTypes(messages, service)
assert.True(t, got.Contains("AReq"))
assert.True(t, got.Contains("AResp"))
assert.False(t, got.Contains("BReq"), "BReq belongs to ServiceB and must be excluded")
assert.False(t, got.Contains("BResp"), "BResp belongs to ServiceB and must be excluded")
}
// ---- integration tests via genCallGroup ------------------------------------
// TestGenCallGroup_OnlyUsedTypesAliased verifies that in multi-service mode
// each generated client file aliases only its own request/response types and
// their transitive field dependencies (fix for issues #5481 and #5618).
func TestGenCallGroup_OnlyUsedTypesAliased(t *testing.T) {
tmpDir := t.TempDir()
callBase := filepath.Join(tmpDir, "call")
pbBase := filepath.Join(tmpDir, "pb")
// Pre-create subdirs that genCallGroup will write into.
require.NoError(t, os.MkdirAll(filepath.Join(callBase, "servicea"), 0755))
require.NoError(t, os.MkdirAll(filepath.Join(callBase, "serviceb"), 0755))
require.NoError(t, os.MkdirAll(pbBase, 0755))
mctx := newTestDirContext(t, callBase, pbBase, "ServiceA", "ServiceB")
mctx := &mockDirContext{
callDir: Dir{
Filename: callBase,
Package: "example.com/multitest/call",
Base: "call",
GetChildPackage: func(childPath string) (string, error) {
// Return a package path whose Base() is the lowercase service name.
return filepath.Join(callBase, strings.ToLower(childPath)), nil
},
},
pbDir: Dir{
Filename: pbBase,
Package: "example.com/multitest/pb",
Base: "pb",
},
protoGo: Dir{
// Must differ from "servicea"/"serviceb" so isCallPkgSameToPbPkg stays false
// and alias generation is triggered.
Filename: pbBase,
Package: "example.com/multitest/pb",
Base: "pb",
},
}
// Proto with two services that use completely disjoint message types.
// ServiceA: AResp contains a NormalField of type AItem (issue #5618).
// ServiceB: BResp has no nested message fields.
// AItem must appear in ServiceA's file but not ServiceB's.
protoData := parser.Proto{
Name: "multi.proto",
PbPackage: "pb",
Message: []parser.Message{
{Message: &proto.Message{Name: "AReq"}},
{Message: &proto.Message{Name: "AResp"}},
{Message: &proto.Message{
Name: "AResp",
Elements: []proto.Visitee{
&proto.NormalField{Field: &proto.Field{Name: "items", Type: "AItem"}},
},
}},
{Message: &proto.Message{Name: "AItem"}},
{Message: &proto.Message{Name: "BReq"}},
{Message: &proto.Message{Name: "BResp"}},
},
@@ -99,29 +310,163 @@ func TestGenCallGroup_OnlyUsedTypesAliased(t *testing.T) {
cfg, err := conf.NewConfig("")
require.NoError(t, err)
require.NoError(t, NewGenerator("gozero", false).genCallGroup(mctx, protoData, cfg))
g := NewGenerator("gozero", false)
require.NoError(t, g.genCallGroup(mctx, protoData, cfg))
aFile := normalizeWS(readGenFile(t, callBase, "servicea", "servicea.go"))
assert.Contains(t, aFile, "AReq = pb.AReq", "ServiceA must alias AReq")
assert.Contains(t, aFile, "AResp = pb.AResp", "ServiceA must alias AResp")
assert.Contains(t, aFile, "AItem = pb.AItem", "ServiceA must alias AItem (transitive NormalField)")
assert.NotContains(t, aFile, "BReq = pb.BReq", "ServiceA must not alias BReq")
assert.NotContains(t, aFile, "BResp = pb.BResp", "ServiceA must not alias BResp")
// servicea/servicea.go — aliases for AReq/AResp only
aContent, err := os.ReadFile(filepath.Join(callBase, "servicea", "servicea.go"))
bFile := normalizeWS(readGenFile(t, callBase, "serviceb", "serviceb.go"))
assert.Contains(t, bFile, "BReq = pb.BReq", "ServiceB must alias BReq")
assert.Contains(t, bFile, "BResp = pb.BResp", "ServiceB must alias BResp")
assert.NotContains(t, bFile, "AReq = pb.AReq", "ServiceB must not alias AReq")
assert.NotContains(t, bFile, "AResp = pb.AResp", "ServiceB must not alias AResp")
assert.NotContains(t, bFile, "AItem = pb.AItem", "ServiceB must not alias AItem")
}
// TestGenCallGroup_MapValueAliased verifies that the value type of a MapField
// inside a service response is included in the generated aliases.
func TestGenCallGroup_MapValueAliased(t *testing.T) {
tmpDir := t.TempDir()
callBase := filepath.Join(tmpDir, "call")
pbBase := filepath.Join(tmpDir, "pb")
mctx := newTestDirContext(t, callBase, pbBase, "ServiceA")
protoData := parser.Proto{
Name: "map.proto",
PbPackage: "pb",
Message: []parser.Message{
{Message: &proto.Message{Name: "AReq"}},
{Message: &proto.Message{
Name: "AResp",
Elements: []proto.Visitee{
&proto.MapField{KeyType: "string", Field: &proto.Field{Name: "index", Type: "AItem"}},
},
}},
{Message: &proto.Message{Name: "AItem"}},
},
Service: parser.Services{
{
Service: &proto.Service{Name: "ServiceA"},
RPC: []*parser.RPC{
{RPC: &proto.RPC{Name: "GetMap", RequestType: "AReq", ReturnsType: "AResp"}},
},
},
},
}
cfg, err := conf.NewConfig("")
require.NoError(t, err)
aFile := normalizeWS(string(aContent))
require.NoError(t, NewGenerator("gozero", false).genCallGroup(mctx, protoData, cfg))
assert.Contains(t, aFile, "AReq = pb.AReq", "ServiceA file should alias AReq")
assert.Contains(t, aFile, "AResp = pb.AResp", "ServiceA file should alias AResp")
assert.NotContains(t, aFile, "BReq = pb.BReq", "ServiceA file must not alias BReq")
assert.NotContains(t, aFile, "BResp = pb.BResp", "ServiceA file must not alias BResp")
aFile := normalizeWS(readGenFile(t, callBase, "servicea", "servicea.go"))
assert.Contains(t, aFile, "AResp = pb.AResp")
assert.Contains(t, aFile, "AItem = pb.AItem", "map value type AItem must be aliased")
}
// serviceb/serviceb.go — aliases for BReq/BResp only
bContent, err := os.ReadFile(filepath.Join(callBase, "serviceb", "serviceb.go"))
// TestGenCallGroup_OneofAliased verifies that message types referenced inside a
// Oneof element are included in the generated aliases.
func TestGenCallGroup_OneofAliased(t *testing.T) {
tmpDir := t.TempDir()
callBase := filepath.Join(tmpDir, "call")
pbBase := filepath.Join(tmpDir, "pb")
mctx := newTestDirContext(t, callBase, pbBase, "ServiceA")
oneof := &proto.Oneof{Name: "result"}
oneof.Elements = []proto.Visitee{
&proto.OneOfField{Field: &proto.Field{Name: "ok", Type: "SuccessMsg"}},
&proto.OneOfField{Field: &proto.Field{Name: "err", Type: "FailureMsg"}},
}
protoData := parser.Proto{
Name: "oneof.proto",
PbPackage: "pb",
Message: []parser.Message{
{Message: &proto.Message{Name: "AReq"}},
{Message: &proto.Message{
Name: "AResp",
Elements: []proto.Visitee{oneof},
}},
{Message: &proto.Message{Name: "SuccessMsg"}},
{Message: &proto.Message{Name: "FailureMsg"}},
},
Service: parser.Services{
{
Service: &proto.Service{Name: "ServiceA"},
RPC: []*parser.RPC{
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
},
},
},
}
cfg, err := conf.NewConfig("")
require.NoError(t, err)
bFile := normalizeWS(string(bContent))
require.NoError(t, NewGenerator("gozero", false).genCallGroup(mctx, protoData, cfg))
assert.Contains(t, bFile, "BReq = pb.BReq", "ServiceB file should alias BReq")
assert.Contains(t, bFile, "BResp = pb.BResp", "ServiceB file should alias BResp")
assert.NotContains(t, bFile, "AReq = pb.AReq", "ServiceB file must not alias AReq")
assert.NotContains(t, bFile, "AResp = pb.AResp", "ServiceB file must not alias AResp")
aFile := normalizeWS(readGenFile(t, callBase, "servicea", "servicea.go"))
assert.Contains(t, aFile, "SuccessMsg = pb.SuccessMsg", "oneof type SuccessMsg must be aliased")
assert.Contains(t, aFile, "FailureMsg = pb.FailureMsg", "oneof type FailureMsg must be aliased")
}
// TestGenCallGroup_MultiLevelTransitiveAliased verifies that a dependency chain
// AResp → BMsg → CMsg causes all three types to be aliased in the client file.
func TestGenCallGroup_MultiLevelTransitiveAliased(t *testing.T) {
tmpDir := t.TempDir()
callBase := filepath.Join(tmpDir, "call")
pbBase := filepath.Join(tmpDir, "pb")
mctx := newTestDirContext(t, callBase, pbBase, "ServiceA")
protoData := parser.Proto{
Name: "transitive.proto",
PbPackage: "pb",
Message: []parser.Message{
{Message: &proto.Message{Name: "AReq"}},
{Message: &proto.Message{
Name: "AResp",
Elements: []proto.Visitee{
&proto.NormalField{Field: &proto.Field{Name: "b", Type: "BMsg"}},
},
}},
{Message: &proto.Message{
Name: "BMsg",
Elements: []proto.Visitee{
&proto.NormalField{Field: &proto.Field{Name: "c", Type: "CMsg"}},
},
}},
{Message: &proto.Message{Name: "CMsg"}},
},
Service: parser.Services{
{
Service: &proto.Service{Name: "ServiceA"},
RPC: []*parser.RPC{
{RPC: &proto.RPC{Name: "Do", RequestType: "AReq", ReturnsType: "AResp"}},
},
},
},
}
cfg, err := conf.NewConfig("")
require.NoError(t, err)
require.NoError(t, NewGenerator("gozero", false).genCallGroup(mctx, protoData, cfg))
aFile := normalizeWS(readGenFile(t, callBase, "servicea", "servicea.go"))
assert.Contains(t, aFile, "AResp = pb.AResp")
assert.Contains(t, aFile, "BMsg = pb.BMsg", "BMsg must be transitively aliased via AResp")
assert.Contains(t, aFile, "CMsg = pb.CMsg", "CMsg must be transitively aliased via BMsg")
}
// readGenFile reads a generated file relative to callBase and returns its content.
func readGenFile(t *testing.T, callBase string, parts ...string) string {
t.Helper()
content, err := os.ReadFile(filepath.Join(append([]string{callBase}, parts...)...))
require.NoError(t, err)
return string(content)
}
// normalizeWS replaces runs of whitespace with a single space.