Skip to content

Commit

Permalink
feat: support generic interface generation (#175)
Browse files Browse the repository at this point in the history
Allow the generation of mocks for generics as introduced in golang 1.18
  • Loading branch information
cgorenflo authored Oct 1, 2022
1 parent e85ff8e commit 13aa048
Show file tree
Hide file tree
Showing 10 changed files with 526 additions and 32 deletions.
8 changes: 7 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
module github.com/matryer/moq

go 1.14
go 1.18

require (
github.com/pmezard/go-difflib v1.0.0
golang.org/x/tools v0.1.10
)

require (
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
)
21 changes: 0 additions & 21 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,31 +1,10 @@
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 h1:kQgndtyPBW/JIYERgdxfwMYh3AVStj88WQTlNDi2a+o=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 h1:id054HUawV2/6IGm2IV8KZQjqtwAOo2CYlOToYqa0d0=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
14 changes: 10 additions & 4 deletions internal/registry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,23 @@ func (r Registry) SrcPkgName() string {

// LookupInterface returns the underlying interface definition of the
// given interface name.
func (r Registry) LookupInterface(name string) (*types.Interface, error) {
func (r Registry) LookupInterface(name string) (*types.Interface, *types.TypeParamList, error) {
obj := r.SrcPkg().Scope().Lookup(name)
if obj == nil {
return nil, fmt.Errorf("interface not found: %s", name)
return nil, nil, fmt.Errorf("interface not found: %s", name)
}

if !types.IsInterface(obj.Type()) {
return nil, fmt.Errorf("%s (%s) is not an interface", name, obj.Type())
return nil, nil, fmt.Errorf("%s (%s) is not an interface", name, obj.Type())
}

return obj.Type().Underlying().(*types.Interface).Complete(), nil
var tparams *types.TypeParamList
named, ok := obj.Type().(*types.Named)
if ok {
tparams = named.TypeParams()
}

return obj.Type().Underlying().(*types.Interface).Complete(), tparams, nil
}

// MethodScope returns a new MethodScope.
Expand Down
40 changes: 36 additions & 4 deletions internal/template/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,22 @@ import (
{{- if not $.SkipEnsure -}}
// Ensure, that {{.MockName}} does implement {{$.SrcPkgQualifier}}{{.InterfaceName}}.
// If this is not the case, regenerate this file with moq.
var _ {{$.SrcPkgQualifier}}{{.InterfaceName}} = &{{.MockName}}{}
var _ {{$.SrcPkgQualifier}}{{.InterfaceName -}}
{{- if .TypeParams }}[
{{- range $index, $param := .TypeParams}}
{{- if $index}}, {{end -}}
{{if $param.Constraint}}{{$param.Constraint.String}}{{else}}{{$param.TypeString}}{{end}}
{{- end -}}
]
{{- end }} = &{{.MockName}}
{{- if .TypeParams }}[
{{- range $index, $param := .TypeParams}}
{{- if $index}}, {{end -}}
{{if $param.Constraint}}{{$param.Constraint.String}}{{else}}{{$param.TypeString}}{{end}}
{{- end -}}
]
{{- end -}}
{}
{{- end}}
// {{.MockName}} is a mock implementation of {{$.SrcPkgQualifier}}{{.InterfaceName}}.
Expand All @@ -68,7 +83,12 @@ var _ {{$.SrcPkgQualifier}}{{.InterfaceName}} = &{{.MockName}}{}
// // and then make assertions.
//
// }
type {{.MockName}} struct {
type {{.MockName}}
{{- if .TypeParams -}}
[{{- range $index, $param := .TypeParams}}
{{- if $index}}, {{end}}{{$param.Name | Exported}} {{$param.TypeString}}
{{- end -}}]
{{- end }} struct {
{{- range .Methods}}
// {{.Name}}Func mocks the {{.Name}} method.
{{.Name}}Func func({{.ArgList}}) {{.ReturnArgTypeList}}
Expand All @@ -91,7 +111,13 @@ type {{.MockName}} struct {
}
{{range .Methods}}
// {{.Name}} calls {{.Name}}Func.
func (mock *{{$mock.MockName}}) {{.Name}}({{.ArgList}}) {{.ReturnArgTypeList}} {
func (mock *{{$mock.MockName}}
{{- if $mock.TypeParams -}}
[{{- range $index, $param := $mock.TypeParams}}
{{- if $index}}, {{end}}{{$param.Name | Exported}}
{{- end -}}]
{{- end -}}
) {{.Name}}({{.ArgList}}) {{.ReturnArgTypeList}} {
{{- if not $.StubImpl}}
if mock.{{.Name}}Func == nil {
panic("{{$mock.MockName}}.{{.Name}}Func: method is nil but {{$mock.InterfaceName}}.{{.Name}} was just called")
Expand Down Expand Up @@ -134,7 +160,13 @@ func (mock *{{$mock.MockName}}) {{.Name}}({{.ArgList}}) {{.ReturnArgTypeList}} {
// {{.Name}}Calls gets all the calls that were made to {{.Name}}.
// Check the length with:
// len(mocked{{$mock.InterfaceName}}.{{.Name}}Calls())
func (mock *{{$mock.MockName}}) {{.Name}}Calls() []struct {
func (mock *{{$mock.MockName}}
{{- if $mock.TypeParams -}}
[{{- range $index, $param := $mock.TypeParams}}
{{- if $index}}, {{end}}{{$param.Name | Exported}}
{{- end -}}]
{{- end -}}
) {{.Name}}Calls() []struct {
{{- range .Params}}
{{.Name | Exported}} {{.TypeString}}
{{- end}}
Expand Down
7 changes: 7 additions & 0 deletions internal/template/template_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package template

import (
"fmt"
"go/types"
"strings"

"github.com/matryer/moq/internal/registry"
Expand Down Expand Up @@ -33,6 +34,7 @@ func (d Data) MocksSomeMethod() bool {
type MockData struct {
InterfaceName string
MockName string
TypeParams []TypeParamData
Methods []MethodData
}

Expand Down Expand Up @@ -87,6 +89,11 @@ func (m MethodData) ReturnArgNameList() string {
return strings.Join(params, ", ")
}

type TypeParamData struct {
ParamData
Constraint types.Type
}

// ParamData is the data which represents a parameter to some method of
// an interface.
type ParamData struct {
Expand Down
41 changes: 40 additions & 1 deletion pkg/moq/moq.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package moq
import (
"bytes"
"errors"
"go/token"
"go/types"
"io"
"strings"
Expand Down Expand Up @@ -57,7 +58,7 @@ func (m *Mocker) Mock(w io.Writer, namePairs ...string) error {
mocks := make([]template.MockData, len(namePairs))
for i, np := range namePairs {
name, mockName := parseInterfaceName(np)
iface, err := m.registry.LookupInterface(name)
iface, tparams, err := m.registry.LookupInterface(name)
if err != nil {
return err
}
Expand All @@ -71,6 +72,7 @@ func (m *Mocker) Mock(w io.Writer, namePairs ...string) error {
InterfaceName: name,
MockName: mockName,
Methods: methods,
TypeParams: m.typeParams(tparams),
}
}

Expand Down Expand Up @@ -110,6 +112,43 @@ func (m *Mocker) Mock(w io.Writer, namePairs ...string) error {
return nil
}

func (m *Mocker) typeParams(tparams *types.TypeParamList) []template.TypeParamData {
var tpd []template.TypeParamData
if tparams == nil {
return tpd
}

tpd = make([]template.TypeParamData, tparams.Len())

scope := m.registry.MethodScope()
for i := 0; i < len(tpd); i++ {
tp := tparams.At(i)
typeParam := types.NewParam(token.Pos(i), tp.Obj().Pkg(), tp.Obj().Name(), tp.Constraint())
tpd[i] = template.TypeParamData{
ParamData: template.ParamData{Var: scope.AddVar(typeParam, "")},
Constraint: explicitConstraintType(typeParam),
}
}

return tpd
}

func explicitConstraintType(typeParam *types.Var) (t types.Type) {
underlying := typeParam.Type().Underlying().(*types.Interface)
// check if any of the embedded types is either a basic type or a union,
// because the generic type has to be an alias for one of those types then
for j := 0; j < underlying.NumEmbeddeds(); j++ {
t := underlying.EmbeddedType(j)
switch t := t.(type) {
case *types.Basic:
return t
case *types.Union: // only unions of basic types are allowed, so just take the first one as a valid type constraint
return t.Term(0).Type()
}
}
return nil
}

func (m *Mocker) methodData(f *types.Func) template.MethodData {
sig := f.Type().(*types.Signature)

Expand Down
6 changes: 6 additions & 0 deletions pkg/moq/moq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,12 @@ func TestMockGolden(t *testing.T) {
interfaces: []string{"ShadowTypes"},
goldenFile: filepath.Join("testpackages/shadowtypes", "shadowtypes_moq.golden.go"),
},
{
name: "Generics",
cfg: Config{SrcDir: "testpackages/generics"},
interfaces: []string{"GenericStore1", "GenericStore2", "AliasStore"},
goldenFile: filepath.Join("testpackages/generics", "generics_moq.golden.go"),
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
Expand Down
32 changes: 32 additions & 0 deletions pkg/moq/testpackages/generics/generics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package generics

import (
"context"
"fmt"
)

type GenericStore1[T Key1, S any] interface {
Get(ctx context.Context, id T) (S, error)
Create(ctx context.Context, id T, value S) error
}

type GenericStore2[T Key2, S any] interface {
Get(ctx context.Context, id T) (S, error)
Create(ctx context.Context, id T, value S) error
}

type AliasStore GenericStore1[KeyImpl, bool]

type Key1 interface {
fmt.Stringer
}

type Key2 interface {
~[]byte | string
}

type KeyImpl []byte

func (x KeyImpl) String() string {
return string(x)
}
Loading

0 comments on commit 13aa048

Please sign in to comment.