diff --git a/generator/generator.go b/generator/generator.go index 2f2f2fe5..9549631b 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -3,6 +3,7 @@ package generator import ( "bytes" "path/filepath" + "sort" "strings" "go/ast" @@ -35,7 +36,46 @@ type TemplateInputs struct { // Interface information for template Interface TemplateInputInterface // Vars additional vars to pass to the template, see Options.Vars - Vars map[string]interface{} + Vars map[string]interface{} + Imports []string +} + +// Import generates an import statement using a list of imports from the source file +// along with the ones from the template itself +func (t TemplateInputs) Import(imports ...string) string { + allImports := make(map[string]struct{}, len(imports)+len(t.Imports)) + + for _, i := range t.Imports { + allImports[strings.TrimSpace(i)] = struct{}{} + } + + for _, i := range imports { + if len(i) == 0 { + continue + } + + i = strings.TrimSpace(i) + + if i[len(i)-1] != '"' { + i += `"` + } + + if i[0] != '"' { + i = `"` + i + } + + allImports[i] = struct{}{} + } + + out := make([]string, 0, len(allImports)) + + for i := range allImports { + out = append(out, i) + } + + sort.Strings(out) + + return "import (\n" + strings.Join(out, "\n") + ")\n" } // TemplateInputInterface subset of interface information used for template generation @@ -131,8 +171,12 @@ func NewGenerator(options Options) (*Generator, error) { if srcPackage.PkgPath == dstPackage.PkgPath { interfaceType = options.InterfaceName srcPackageAST.Name = "" - } else if options.SourcePackageAlias != "" { - srcPackageAST.Name = options.SourcePackageAlias + } else { + if options.SourcePackageAlias != "" { + srcPackageAST.Name = options.SourcePackageAlias + } + + options.Imports = append(options.Imports, `"`+srcPackage.PkgPath+`"`) } methods, imports, err := findInterface(fs, srcPackageAST, options.InterfaceName) @@ -150,7 +194,7 @@ func NewGenerator(options Options) (*Generator, error) { } } - options.Imports = makeImports(imports) + options.Imports = append(options.Imports, makeImports(imports)...) return &Generator{ Options: options, @@ -219,7 +263,8 @@ func (g Generator) Generate(w io.Writer) error { Type: g.interfaceType, Methods: g.methods, }, - Vars: g.Options.Vars, + Imports: g.Options.Imports, + Vars: g.Options.Vars, }) if err != nil { return err diff --git a/generator/generator_test.go b/generator/generator_test.go index dfb8d1e9..5695c8a0 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -484,7 +484,7 @@ func TestGenerator_Generate(t *testing.T) { tests := []struct { name string init func(t minimock.Tester) Generator - inspect func(r Generator, t *testing.T) //inspects Generator after execution of Generate + inspect func(r Generator, w io.Writer, t *testing.T) //inspects Generator after execution of Generate args func(t minimock.Tester) args @@ -551,6 +551,36 @@ func TestGenerator_Generate(t *testing.T) { }, wantErr: false, }, + { + name: "imports can be generated", + init: func(t minimock.Tester) Generator { + return Generator{ + Options: Options{ + Imports: []string{`"github.com/pkg/errors"`, `"github.com/sirupsen/logrus"`}, + }, + headerTemplate: template.Must(template.New("header").Parse("package success\n")), + bodyTemplate: template.Must(template.New("body").Parse(` + {{.Import "github.com/sirupsen/logrus" }} + func test(l *logrus.Logger) {} + `)), + } + }, + args: func(t minimock.Tester) args { + return args{ + w: bytes.NewBuffer([]byte{}), + } + }, + inspect: func(_ Generator, w io.Writer, t *testing.T) { + assert.Equal(t, `package success + +import ( + "github.com/sirupsen/logrus" +) + +func test(l *logrus.Logger) {} +`, w.(*bytes.Buffer).String()) + }, + }, } for _, tt := range tests { @@ -564,7 +594,7 @@ func TestGenerator_Generate(t *testing.T) { err := receiver.Generate(tArgs.w) if tt.inspect != nil { - tt.inspect(receiver, t) + tt.inspect(receiver, tArgs.w, t) } if tt.wantErr {