diff --git a/main.go b/main.go index 46037b0..8133040 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,8 @@ import ( "io" "io/ioutil" "os" + "path" + "path/filepath" "sort" "strconv" "strings" @@ -79,49 +81,74 @@ func main() { } } -func realMain() error { - cfg, err := parseConfig(os.Args[1:]) +func GetFileNameList(pathname string) ([]string, error) { + var ret []string + dirname := path.Dir(pathname) + filename := path.Base(pathname) + dir, _ := os.Open(dirname) + defer dir.Close() + fileInfos, err := dir.Readdir(0) if err != nil { - if err == flag.ErrHelp { - return nil + return ret, errors.New("Error reading directory:" + dirname) + } + for _, fileInfo := range fileInfos { + if fileInfo.IsDir() { + continue + } + matched, err := filepath.Match(filename, fileInfo.Name()) + if err == nil && matched { + ret = append(ret, filepath.Join(dirname, fileInfo.Name())) } - return err } + return ret, err +} - err = cfg.validate() +func realMain() error { + cfgList, err := parseConfig(os.Args[1:]) if err != nil { + if err == flag.ErrHelp { + return nil + } return err } + for _, cfg := range cfgList { + var node ast.Node + var start, end int + err = cfg.validate() + if err != nil { + return err + } - node, err := cfg.parse() - if err != nil { - return err - } + node, err = cfg.parse() + if err != nil { + return err + } - start, end, err := cfg.findSelection(node) - if err != nil { - return err - } + start, end, err = cfg.findSelection(node) + if err != nil { + return err + } - rewrittenNode, errs := cfg.rewrite(node, start, end) - if errs != nil { - if _, ok := errs.(*rewriteErrors); !ok { - return errs + rewrittenNode, errs := cfg.rewrite(node, start, end) + if errs != nil { + if _, ok := errs.(*rewriteErrors); !ok { + return errs + } } - } - out, err := cfg.format(rewrittenNode, errs) - if err != nil { - return err - } + out, err := cfg.format(rewrittenNode, errs) + if err != nil { + return err + } - if !cfg.quiet { - fmt.Println(out) + if !cfg.quiet { + fmt.Println(out) + } } return nil } -func parseConfig(args []string) (*config, error) { +func parseConfig(args []string) ([]*config, error) { var ( // file flags flagFile = flag.String("file", "", "Filename to be parsed") @@ -183,47 +210,54 @@ func parseConfig(args []string) (*config, error) { return nil, flag.ErrHelp } - cfg := &config{ - file: *flagFile, - line: *flagLine, - structName: *flagStruct, - fieldName: *flagField, - offset: *flagOffset, - all: *flagAll, - output: *flagOutput, - write: *flagWrite, - quiet: *flagQuiet, - clear: *flagClearTags, - clearOption: *flagClearOptions, - transform: *flagTransform, - sort: *flagSort, - valueFormat: *flagFormatting, - override: *flagOverride, - skipUnexportedFields: *flagSkipUnexportedFields, + fileNameList, err := GetFileNameList(*flagFile) + if err != nil { + fmt.Fprintf(os.Stderr, "get file name error, %s:\n", err.Error()) + return nil, errors.New("get file name error, " + err.Error()) } + var cfgList []*config + for _, v := range fileNameList { + cfg := &config{ + file: v, + line: *flagLine, + structName: *flagStruct, + fieldName: *flagField, + offset: *flagOffset, + all: *flagAll, + output: *flagOutput, + write: *flagWrite, + quiet: *flagQuiet, + clear: *flagClearTags, + clearOption: *flagClearOptions, + transform: *flagTransform, + sort: *flagSort, + valueFormat: *flagFormatting, + override: *flagOverride, + skipUnexportedFields: *flagSkipUnexportedFields, + } - if *flagModified { - cfg.modified = os.Stdin - } + if *flagModified { + cfg.modified = os.Stdin + } - if *flagAddTags != "" { - cfg.add = strings.Split(*flagAddTags, ",") - } + if *flagAddTags != "" { + cfg.add = strings.Split(*flagAddTags, ",") + } - if *flagAddOptions != "" { - cfg.addOptions = strings.Split(*flagAddOptions, ",") - } + if *flagAddOptions != "" { + cfg.addOptions = strings.Split(*flagAddOptions, ",") + } - if *flagRemoveTags != "" { - cfg.remove = strings.Split(*flagRemoveTags, ",") - } + if *flagRemoveTags != "" { + cfg.remove = strings.Split(*flagRemoveTags, ",") + } - if *flagRemoveOptions != "" { - cfg.removeOptions = strings.Split(*flagRemoveOptions, ",") + if *flagRemoveOptions != "" { + cfg.removeOptions = strings.Split(*flagRemoveOptions, ",") + } + cfgList = append(cfgList, cfg) } - - return cfg, nil - + return cfgList, nil } func (c *config) parse() (ast.Node, error) { diff --git a/main_test.go b/main_test.go index 1facda4..3a94b9e 100644 --- a/main_test.go +++ b/main_test.go @@ -940,3 +940,30 @@ func TestParseConfig(t *testing.T) { t.Fatal(err) } } + +func TestGetFileNameList(t *testing.T) { + type args struct { + filename string + } + tests := []struct { + name string + args args + want []string + wantErr bool + }{ + {name: "normal", args: args{"main.go"}, wantErr: false, want: []string{"main.go"}}, + {name: "reg", args: args{"*.go"}, wantErr: false, want: []string{"main.go", "main_test.go"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := GetFileNameList(tt.args.filename) + if (err != nil) != tt.wantErr { + t.Errorf("GetFileNameList() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetFileNameList() got = %v, want %v", got, tt.want) + } + }) + } +}