@@ -12,8 +12,10 @@ import (
1212 "path"
1313 "strings"
1414
15- "github.com/mpyw/sqlc-restruct/pkg/internal/astutil"
1615 "golang.org/x/exp/slices"
16+ "golang.org/x/tools/imports"
17+
18+ "github.com/mpyw/sqlc-restruct/pkg/internal/astutil"
1719)
1820
1921type runner struct {
@@ -43,19 +45,32 @@ func (r *runner) Run() error {
4345 if newModelsContent , err = r .newModelsContent (); err != nil {
4446 return err
4547 }
48+ newModelsContent , err = imports .Process ("" , newModelsContent , nil )
49+ if err != nil {
50+ return err
51+ }
4652 continue
4753 }
4854 if filename == r .input .QuerierFileName {
4955 if newQuerierContent , err = r .newQuerierContent (); err != nil {
5056 return err
5157 }
58+ newQuerierContent , err = imports .Process ("" , newQuerierContent , nil )
59+ if err != nil {
60+ return err
61+ }
5262 continue
5363 }
54- if strings . HasSuffix (filename , r . input . ImplSQLSuffix ) {
64+ if r . isImplFile (filename ) {
5565 var newQueriesContent []byte
5666 if newQueriesContent , err = r .newQueriesContent (filename ); err != nil {
5767 return err
5868 }
69+ newQueriesContent , err = imports .Process ("" , newQueriesContent , nil )
70+ if err != nil {
71+ return err
72+ }
73+
5974 if newQueriesContents == nil {
6075 newQueriesContents = make (map [string ][]byte )
6176 }
@@ -66,21 +81,21 @@ func (r *runner) Run() error {
6681
6782 if newModelsContent != nil {
6883 _ = os .Remove (path .Join (r .input .ModelsDir , r .input .ModelsFileName ))
69- if err := os .WriteFile (path .Join (r .input .ModelsDir , r .input .ModelsFileName ), newModelsContent , 0644 ); err != nil {
84+ if err := os .WriteFile (path .Join (r .input .ModelsDir , r .input .ModelsFileName ), newModelsContent , 0o644 ); err != nil {
7085 return fmt .Errorf ("runner.Run() failed: %w" , err )
7186 }
7287 _ = os .Remove (path .Join (r .input .ImplDir , r .input .ModelsFileName ))
7388 }
7489 if newQuerierContent != nil {
7590 _ = os .Remove (path .Join (r .input .IfaceDir , r .input .QuerierFileName ))
76- if err := os .WriteFile (path .Join (r .input .IfaceDir , r .input .QuerierFileName ), newQuerierContent , 0644 ); err != nil {
91+ if err := os .WriteFile (path .Join (r .input .IfaceDir , r .input .QuerierFileName ), newQuerierContent , 0o644 ); err != nil {
7792 return fmt .Errorf ("runner.Run() failed: %w" , err )
7893 }
7994 _ = os .Remove (path .Join (r .input .ImplDir , r .input .QuerierFileName ))
8095 }
8196 for filename , content := range newQueriesContents {
8297 _ = os .Remove (path .Join (r .input .ImplDir , filename ))
83- if err := os .WriteFile (path .Join (r .input .ImplDir , filename ), content , 0644 ); err != nil {
98+ if err := os .WriteFile (path .Join (r .input .ImplDir , filename ), content , 0o644 ); err != nil {
8499 return fmt .Errorf ("runner.Run() failed: %w" , err )
85100 }
86101 }
@@ -168,7 +183,7 @@ func (r *runner) newQuerierContent() ([]byte, error) {
168183 }
169184
170185 for _ , dirEntry := range dirEntries {
171- if ! strings . HasSuffix (dirEntry .Name (), r . input . ImplSQLSuffix ) {
186+ if ! r . isImplFile (dirEntry .Name ()) {
172187 continue
173188 }
174189
@@ -273,3 +288,16 @@ func (r *runner) newQueriesContent(filename string) ([]byte, error) {
273288 }
274289 return byt , nil
275290}
291+
292+ func (r * runner ) isImplFile (filename string ) bool {
293+ return strings .HasSuffix (filename , r .input .ImplSQLSuffix ) || inStrings (filename , r .input .AditionalQuerierFiles )
294+ }
295+
296+ func inStrings (s string , slice []string ) bool {
297+ for _ , v := range slice {
298+ if v == s {
299+ return true
300+ }
301+ }
302+ return false
303+ }
0 commit comments