Skip to content

Commit 6da9266

Browse files
authored
Merge pull request #465 from aofei/CodeBuilder
fix(CodeBuilder): prevent infinite recursion in struct field lookup
2 parents 682e395 + b5ebdba commit 6da9266

File tree

3 files changed

+47
-10
lines changed

3 files changed

+47
-10
lines changed

builtin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1142,7 +1142,7 @@ func (p addableT) Match(pkg *Package, typ types.Type) bool {
11421142
// TODO: refactor
11431143
cb := pkg.cb
11441144
cb.stk.Push(elemNone)
1145-
kind := cb.findMember(typ, "Gop_Add", "", MemberFlagVal, &Element{}, nil)
1145+
kind := cb.findMember(typ, "Gop_Add", "", MemberFlagVal, &Element{}, nil, nil)
11461146
if kind != 0 {
11471147
cb.stk.PopN(1)
11481148
if kind == MemberMethod {

codebuild.go

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1569,7 +1569,7 @@ func (p *CodeBuilder) Member(name string, flag MemberFlag, src ...ast.Node) (kin
15691569
flag = memberFlagMethodToFunc
15701570
}
15711571
aliasName, flag = aliasNameOf(name, flag)
1572-
kind = p.findMember(at, name, aliasName, flag, arg, srcExpr)
1572+
kind = p.findMember(at, name, aliasName, flag, arg, srcExpr, nil)
15731573
if isType && kind != MemberMethod {
15741574
code, pos := p.loadExpr(srcExpr)
15751575
return MemberInvalid, p.newCodeError(
@@ -1617,7 +1617,7 @@ func getUnderlying(pkg *Package, typ types.Type) types.Type {
16171617
}
16181618

16191619
func (p *CodeBuilder) findMember(
1620-
typ types.Type, name, aliasName string, flag MemberFlag, arg *Element, srcExpr ast.Node) MemberKind {
1620+
typ types.Type, name, aliasName string, flag MemberFlag, arg *Element, srcExpr ast.Node, visited map[*types.Struct]none) MemberKind {
16211621
var named *types.Named
16221622
retry:
16231623
switch o := typ.(type) {
@@ -1635,10 +1635,10 @@ retry:
16351635
return kind
16361636
}
16371637
if fstruc {
1638-
return p.embeddedField(struc, name, aliasName, flag, arg, srcExpr)
1638+
return p.embeddedField(struc, name, aliasName, flag, arg, srcExpr, visited)
16391639
}
16401640
case *types.Struct:
1641-
if kind := p.field(t, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
1641+
if kind := p.field(t, name, aliasName, flag, arg, srcExpr, visited); kind != MemberInvalid {
16421642
return kind
16431643
}
16441644
}
@@ -1649,7 +1649,7 @@ retry:
16491649
}
16501650
goto retry
16511651
case *types.Struct:
1652-
if kind := p.field(o, name, aliasName, flag, arg, srcExpr); kind != MemberInvalid {
1652+
if kind := p.field(o, name, aliasName, flag, arg, srcExpr, visited); kind != MemberInvalid {
16531653
return kind
16541654
}
16551655
if named != nil {
@@ -1829,10 +1829,18 @@ func (p *CodeBuilder) normalField(
18291829
}
18301830

18311831
func (p *CodeBuilder) embeddedField(
1832-
o *types.Struct, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node) MemberKind {
1832+
o *types.Struct, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node, visited map[*types.Struct]none) MemberKind {
1833+
if visited == nil {
1834+
visited = make(map[*types.Struct]none)
1835+
}
1836+
if _, ok := visited[o]; ok {
1837+
return MemberInvalid
1838+
}
1839+
visited[o] = none{}
1840+
18331841
for i, n := 0, o.NumFields(); i < n; i++ {
18341842
if fld := o.Field(i); fld.Embedded() {
1835-
if kind := p.findMember(fld.Type(), name, aliasName, flag, arg, src); kind != MemberInvalid {
1843+
if kind := p.findMember(fld.Type(), name, aliasName, flag, arg, src, visited); kind != MemberInvalid {
18361844
return kind
18371845
}
18381846
}
@@ -1841,11 +1849,11 @@ func (p *CodeBuilder) embeddedField(
18411849
}
18421850

18431851
func (p *CodeBuilder) field(
1844-
o *types.Struct, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node) MemberKind {
1852+
o *types.Struct, name, aliasName string, flag MemberFlag, arg *Element, src ast.Node, visited map[*types.Struct]none) MemberKind {
18451853
if kind := p.normalField(o, name, arg, src); kind != MemberInvalid {
18461854
return kind
18471855
}
1848-
return p.embeddedField(o, name, aliasName, flag, arg, src)
1856+
return p.embeddedField(o, name, aliasName, flag, arg, src, visited)
18491857
}
18501858

18511859
func toFuncSig(sig *types.Signature, recv *types.Var) *types.Signature {

codebuild_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package gogen
2+
3+
import (
4+
"go/token"
5+
"go/types"
6+
"testing"
7+
)
8+
9+
func TestCircularEmbeddedFieldLookup(t *testing.T) {
10+
pkg := NewPackage("", "foo", nil)
11+
cb := pkg.CB()
12+
13+
typeA := types.NewNamed(types.NewTypeName(token.NoPos, pkg.Types, "A", nil), nil, nil)
14+
typeB := types.NewNamed(types.NewTypeName(token.NoPos, pkg.Types, "B", nil), nil, nil)
15+
16+
// Creates a circular embedding relationship between type A and B.
17+
typeA.SetUnderlying(types.NewStruct([]*types.Var{
18+
types.NewField(token.NoPos, pkg.Types, "", typeB, true), // Embed B.
19+
}, nil))
20+
typeB.SetUnderlying(types.NewStruct([]*types.Var{
21+
types.NewField(token.NoPos, pkg.Types, "", typeA, true), // Embed A.
22+
}, nil))
23+
24+
cb.stk.Push(&Element{Type: typeA})
25+
kind, _ := cb.Member("any", MemberFlagVal)
26+
if kind != MemberInvalid {
27+
t.Fatal("Member should return MemberInvalid for circular embedding")
28+
}
29+
}

0 commit comments

Comments
 (0)