Skip to content

Commit

Permalink
ensureLoaded => getUnderlying
Browse files Browse the repository at this point in the history
  • Loading branch information
xushiwei committed Aug 15, 2021
1 parent 6013a45 commit e4762fe
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 23 deletions.
10 changes: 5 additions & 5 deletions builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ retry:
return true
}
case *types.Named:
typ = t.Underlying()
typ = pkg.cb.getUnderlying(t)
goto retry
}
return false
Expand Down Expand Up @@ -686,7 +686,7 @@ retry:
case *types.Basic:
return t.Kind() != types.UntypedNil // excluding nil
case *types.Named:
typ = t.Underlying()
typ = pkg.cb.getUnderlying(t)
goto retry
case *types.Slice: // slice/map/func is very special
return false
Expand Down Expand Up @@ -744,7 +744,7 @@ retry:
_, ok := t.Elem().(*types.Array) // array_pointer
return ok
case *types.Named:
typ = t.Underlying()
typ = pkg.cb.getUnderlying(t)
goto retry
}
return false
Expand All @@ -770,7 +770,7 @@ retry:
case *types.Map:
return true
case *types.Named:
typ = t.Underlying()
typ = pkg.cb.getUnderlying(t)
goto retry
}
return capable.Match(pkg, typ)
Expand All @@ -796,7 +796,7 @@ retry:
case *types.Chan:
return true
case *types.Named:
typ = t.Underlying()
typ = pkg.cb.getUnderlying(t)
goto retry
}
return false
Expand Down
26 changes: 22 additions & 4 deletions builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,37 @@ func TestNodeInterp(t *testing.T) {
}
}

func TestEnsureLoaded(t *testing.T) {
var cb CodeBuilder
func TestGetUnderlying(t *testing.T) {
var pkg = new(Package)
var cb = &pkg.cb
cb.loadNamed = func(at *Package, t *types.Named) {
panic("loadNamed")
}
defaultLoadNamed(nil, nil)
defer func() {
if e := recover(); e != "loadNamed" {
t.Fatal("TestGetUnderlying failed")
}
}()
named := types.NewNamed(types.NewTypeName(0, nil, "foo", nil), nil, nil)
cb.getUnderlying(named)
}

func TestGetUnderlying2(t *testing.T) {
var pkg = new(Package)
var cb = &pkg.cb
cb.pkg = pkg
cb.loadNamed = func(at *Package, t *types.Named) {
panic("loadNamed")
}
defaultLoadNamed(nil, nil)
defer func() {
if e := recover(); e != "loadNamed" {
t.Fatal("TestEnsureLoaded failed")
t.Fatal("TestGetUnderlying2 failed")
}
}()
named := types.NewNamed(types.NewTypeName(0, nil, "foo", nil), nil, nil)
cb.ensureLoaded(named)
getUnderlying(pkg, named)
}

func TestWriteFile(t *testing.T) {
Expand Down
38 changes: 26 additions & 12 deletions codebuild.go
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ retry:
case *types.Chan:
return p.Val(nil)
case *types.Named:
typ = t.Underlying()
typ = p.getUnderlying(t)
goto retry
}
ret := &ast.CompositeLit{}
Expand Down Expand Up @@ -757,7 +757,7 @@ func (p *CodeBuilder) MapLit(typ types.Type, arity int) *CodeBuilder {
switch tt := typ.(type) {
case *types.Named:
typExpr = toNamedType(pkg, tt)
t = tt.Underlying().(*types.Map)
t = p.getUnderlying(tt).(*types.Map)
case *types.Map:
typExpr = toMapType(pkg, tt)
t = tt
Expand Down Expand Up @@ -867,7 +867,7 @@ func (p *CodeBuilder) SliceLit(typ types.Type, arity int, keyVal ...bool) *CodeB
switch tt := typ.(type) {
case *types.Named:
typExpr = toNamedType(pkg, tt)
t = tt.Underlying().(*types.Slice)
t = p.getUnderlying(tt).(*types.Slice)
case *types.Slice:
typExpr = toSliceType(pkg, tt)
t = tt
Expand Down Expand Up @@ -941,7 +941,7 @@ func (p *CodeBuilder) ArrayLit(typ types.Type, arity int, keyVal ...bool) *CodeB
switch tt := typ.(type) {
case *types.Named:
typExpr = toNamedType(pkg, tt)
t = tt.Underlying().(*types.Array)
t = p.getUnderlying(tt).(*types.Array)
case *types.Array:
typExpr = toArrayType(pkg, tt)
t = tt
Expand Down Expand Up @@ -1004,7 +1004,7 @@ func (p *CodeBuilder) StructLit(typ types.Type, arity int, keyVal bool) *CodeBui
switch tt := typ.(type) {
case *types.Named:
typExpr = toNamedType(pkg, tt)
t = tt.Underlying().(*types.Struct)
t = p.getUnderlying(tt).(*types.Struct)
case *types.Struct:
typExpr = toStructType(pkg, tt)
t = tt
Expand Down Expand Up @@ -1346,7 +1346,7 @@ func (p *CodeBuilder) MemberRef(name string, src ...ast.Node) *CodeBuilder {
arg := p.stk.Get(-1)
switch o := indirect(arg.Type).(type) {
case *types.Named:
if struc, ok := o.Underlying().(*types.Struct); ok {
if struc, ok := p.getUnderlying(o).(*types.Struct); ok {
if p.fieldRef(arg.Val, struc, name) {
return p
}
Expand Down Expand Up @@ -1415,22 +1415,36 @@ func (p *CodeBuilder) Member(name string, src ...ast.Node) (kind MemberKind, err
&pos, fmt.Sprintf("%s undefined (type %v has no field or method %s)", code, arg.Type, name))
}

func (p *CodeBuilder) ensureLoaded(t *types.Named) {
if t.Underlying() == nil {
func (p *CodeBuilder) getUnderlying(t *types.Named) types.Type {
u := t.Underlying()
if u == nil {
p.loadNamed(p.pkg, t)
u = t.Underlying()
}
return u
}

func getUnderlying(pkg *Package, typ types.Type) types.Type {
u := typ.Underlying()
if u == nil {
if t, ok := typ.(*types.Named); ok {
pkg.cb.loadNamed(pkg, t)
u = t.Underlying()
}
}
return u
}

func (p *CodeBuilder) findMember(typ types.Type, name string, argVal ast.Expr, srcExpr ast.Node) MemberKind {
switch o := typ.(type) {
case *types.Pointer:
switch t := o.Elem().(type) {
case *types.Named:
p.ensureLoaded(t)
u := p.getUnderlying(t)
if p.method(t, name, argVal, srcExpr) {
return MemberMethod
}
if struc, ok := t.Underlying().(*types.Struct); ok {
if struc, ok := u.(*types.Struct); ok {
if kind := p.field(struc, name, argVal, srcExpr); kind != 0 {
return kind
}
Expand All @@ -1441,11 +1455,11 @@ func (p *CodeBuilder) findMember(typ types.Type, name string, argVal ast.Expr, s
}
}
case *types.Named:
p.ensureLoaded(o)
u := p.getUnderlying(o)
if p.method(o, name, argVal, srcExpr) {
return MemberMethod
}
switch t := o.Underlying().(type) {
switch t := u.(type) {
case *types.Struct:
if kind := p.field(t, name, argVal, srcExpr); kind != 0 {
return kind
Expand Down
2 changes: 1 addition & 1 deletion func.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (p *Package) NewFuncWith(
return nil, cb.newCodePosErrorf(
getRecv(recvTypePos), "invalid receiver type %v (%v is not a defined type)", typ, typ)
}
switch t.Obj().Type().Underlying().(type) {
switch getUnderlying(p, t.Obj().Type()).(type) {
case *types.Interface:
return nil, cb.newCodePosErrorf(
getRecv(recvTypePos), "invalid receiver type %v (%v is an interface type)", typ, typ)
Expand Down
2 changes: 1 addition & 1 deletion template.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func assignable(pkg *Package, v types.Type, t *types.Named, expr *ast.Expr) bool

func ComparableTo(pkg *Package, V, T types.Type) bool {
V, T = types.Default(V), types.Default(T)
if V != T && V.Underlying() != T.Underlying() {
if V != T && getUnderlying(pkg, V) != getUnderlying(pkg, T) {
return false
}
return types.Comparable(V)
Expand Down

0 comments on commit e4762fe

Please sign in to comment.