Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ type compilerContext struct {
pkg *types.Package
packageDir string // directory for this package
runtimePkg *types.Package
scopeIdx map[*types.Scope]int
}

// newCompilerContext returns a new compiler context ready for use, most
Expand All @@ -106,6 +107,7 @@ func newCompilerContext(moduleName string, machine llvm.TargetMachine, config *C
targetData: machine.CreateTargetData(),
functionInfos: map[*ssa.Function]functionInfo{},
astComments: map[string]*ast.CommentGroup{},
scopeIdx: map[*types.Scope]int{},
}

c.ctx = llvm.NewContext()
Expand Down
68 changes: 51 additions & 17 deletions compiler/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
}
}

typeCodeName, isLocal := getTypeCodeName(typ)
typeCodeName, isLocal := c.getTypeCodeName(typ)
globalName := "reflect/types.type:" + typeCodeName
var global llvm.Value
if isLocal {
Expand Down Expand Up @@ -511,23 +511,57 @@ var basicTypeNames = [...]string{
types.UnsafePointer: "unsafe.Pointer",
}

// return an integer representing this scope in a package.
func (c *compilerContext) getScopeID(pkg *types.Scope, scope *types.Scope) string {
var ids []int

for scope != pkg {
if idx, ok := c.scopeIdx[scope]; ok {
ids = append(ids, idx)
} else {
// flesh out scope idx for all children of this parent
parent := scope.Parent()
for i := range parent.NumChildren() {
child := parent.Child(i)
if child == scope {
c.scopeIdx[scope] = i
break
}
}
ids = append(ids, c.scopeIdx[scope])
}
// move up a level
scope = scope.Parent()
}

var buf []byte
for _, v := range ids {
buf = strconv.AppendInt(buf, int64(v), 10)
buf = append(buf, ':')
}

id := string(buf)
return id
}

// getTypeCodeName returns a name for this type that can be used in the
// interface lowering pass to assign type codes as expected by the reflect
// package. See getTypeCodeNum.
func getTypeCodeName(t types.Type) (string, bool) {
func (c *compilerContext) getTypeCodeName(t types.Type) (string, bool) {
switch t := types.Unalias(t).(type) {
case *types.Named:
if t.Obj().Parent() != t.Obj().Pkg().Scope() {
return "named:" + t.String() + "$local", true
parent, pkg := t.Obj().Parent(), t.Obj().Pkg().Scope()
if parent != pkg {
return fmt.Sprintf("named:%s$local:%s", t.String(), c.getScopeID(pkg, parent)), true
}
return "named:" + t.String(), false
case *types.Array:
s, isLocal := getTypeCodeName(t.Elem())
s, isLocal := c.getTypeCodeName(t.Elem())
return "array:" + strconv.FormatInt(t.Len(), 10) + ":" + s, isLocal
case *types.Basic:
return "basic:" + basicTypeNames[t.Kind()], false
case *types.Chan:
s, isLocal := getTypeCodeName(t.Elem())
s, isLocal := c.getTypeCodeName(t.Elem())
var dir string
switch t.Dir() {
case types.SendOnly:
Expand All @@ -547,41 +581,41 @@ func getTypeCodeName(t types.Type) (string, bool) {
if !token.IsExported(name) {
name = t.Method(i).Pkg().Path() + "." + name
}
s, local := getTypeCodeName(t.Method(i).Type())
s, local := c.getTypeCodeName(t.Method(i).Type())
if local {
isLocal = true
}
methods[i] = name + ":" + s
}
return "interface:" + "{" + strings.Join(methods, ",") + "}", isLocal
case *types.Map:
keyType, keyLocal := getTypeCodeName(t.Key())
elemType, elemLocal := getTypeCodeName(t.Elem())
keyType, keyLocal := c.getTypeCodeName(t.Key())
elemType, elemLocal := c.getTypeCodeName(t.Elem())
return "map:" + "{" + keyType + "," + elemType + "}", keyLocal || elemLocal
case *types.Pointer:
s, isLocal := getTypeCodeName(t.Elem())
s, isLocal := c.getTypeCodeName(t.Elem())
return "pointer:" + s, isLocal
case *types.Signature:
isLocal := false
params := make([]string, t.Params().Len())
for i := 0; i < t.Params().Len(); i++ {
s, local := getTypeCodeName(t.Params().At(i).Type())
s, local := c.getTypeCodeName(t.Params().At(i).Type())
if local {
isLocal = true
}
params[i] = s
}
results := make([]string, t.Results().Len())
for i := 0; i < t.Results().Len(); i++ {
s, local := getTypeCodeName(t.Results().At(i).Type())
s, local := c.getTypeCodeName(t.Results().At(i).Type())
if local {
isLocal = true
}
results[i] = s
}
return "func:" + "{" + strings.Join(params, ",") + "}{" + strings.Join(results, ",") + "}", isLocal
case *types.Slice:
s, isLocal := getTypeCodeName(t.Elem())
s, isLocal := c.getTypeCodeName(t.Elem())
return "slice:" + s, isLocal
case *types.Struct:
elems := make([]string, t.NumFields())
Expand All @@ -591,7 +625,7 @@ func getTypeCodeName(t types.Type) (string, bool) {
if t.Field(i).Embedded() {
embedded = "#"
}
s, local := getTypeCodeName(t.Field(i).Type())
s, local := c.getTypeCodeName(t.Field(i).Type())
if local {
isLocal = true
}
Expand Down Expand Up @@ -709,7 +743,7 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value {
commaOk = b.CreateCall(fn.GlobalValueType(), fn, []llvm.Value{actualTypeNum}, "")
}
} else {
name, _ := getTypeCodeName(expr.AssertedType)
name, _ := b.getTypeCodeName(expr.AssertedType)
globalName := "reflect/types.typeid:" + name
assertedTypeCodeGlobal := b.mod.NamedGlobal(globalName)
if assertedTypeCodeGlobal.IsNil() {
Expand Down Expand Up @@ -786,7 +820,7 @@ func (c *compilerContext) getMethodsString(itf *types.Interface) string {
// getInterfaceImplementsFunc returns a declared function that works as a type
// switch. The interface lowering pass will define this function.
func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) llvm.Value {
s, _ := getTypeCodeName(assertedType.Underlying())
s, _ := c.getTypeCodeName(assertedType.Underlying())
fnName := s + ".$typeassert"
llvmFn := c.mod.NamedFunction(fnName)
if llvmFn.IsNil() {
Expand All @@ -803,7 +837,7 @@ func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) ll
// thunk is declared, not defined: it will be defined by the interface lowering
// pass.
func (c *compilerContext) getInvokeFunction(instr *ssa.CallCommon) llvm.Value {
s, _ := getTypeCodeName(instr.Value.Type().Underlying())
s, _ := c.getTypeCodeName(instr.Value.Type().Underlying())
fnName := s + "." + instr.Method.Name() + "$invoke"
llvmFn := c.mod.NamedFunction(fnName)
if llvmFn.IsNil() {
Expand Down
34 changes: 34 additions & 0 deletions compiler/testdata/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,37 @@ func callFooMethod(itf fooInterface) uint8 {
func callErrorMethod(itf error) string {
return itf.Error()
}

func namedFoo() {
type Foo struct {
A int
}
f1 := &Foo{}
fcopy := copyOf(f1)
f2 := fcopy.(*Foo)
println(f2.A)
}

func namedFoo2Nested() {
type Foo struct {
A *int
}
f1 := &Foo{}
fcopy := copyOf(f1)
f2 := fcopy.(*Foo)
println(f2.A == nil)

if f2.A == nil {
type Foo struct {
A *byte
}
nestedf1 := &Foo{}
fcopy := copyOf(nestedf1)
nestedf2 := fcopy.(*Foo)
println(nestedf2.A == nil)
}
}

func copyOf(src interface{}) (dst interface{}) {
return src
}
Loading
Loading