diff --git a/README.md b/README.md index 0b537bd..f22d638 100644 --- a/README.md +++ b/README.md @@ -24,9 +24,36 @@ go install github.com/flaticols/resetgen@latest Or add as a tool dependency (Go 1.24+): ```bash -go get -tool github.com/flaticols/resetgen +go get -tool github.com/flaticols/resetgen@latest ``` +### Go 1.24+ Tool Mechanism + +Go 1.24 introduced the ability to manage CLI tools as dependencies. You can declare tool requirements in `go.mod`: + +```go +tool ( + github.com/flaticols/resetgen +) +``` + +Run with `go tool`: + +```bash +# Generate from current package +go tool resetgen + +# Generate from specific packages +go tool resetgen ./... +go tool resetgen ./cmd ./internal + +# With flags +go tool resetgen -structs User,Order ./... +go tool resetgen -version +``` + +This approach keeps your tool versions synchronized with your project, just like regular dependencies. + ### Usage Add `reset` tags to your struct fields and run the generator: @@ -69,6 +96,136 @@ func (s *Request) Reset() { | `reset:"value"` | Default value | | `reset:"-"` | Skip field | +## CLI Flag Syntax + +### `-structs` Flag + +Specify which structs to generate using the `-structs` flag: + +```bash +//go:generate resetgen -structs User,Order,Config + +# Or with multiple files +resetgen -structs User,Order,Config ./... +``` + +When `-structs` is specified: +- **ONLY** the listed structs are processed (tags and directives are ignored for struct selection) +- All exported fields are reset to zero values +- Field-level `reset` tags still work for custom values or to skip specific fields + +**Example:** +```go +//go:generate resetgen -structs User,Order + +type User struct { + ID int64 + Name string + Secret string `reset:"-"` // Still respected - field will not be reset +} + +type Order struct { + ID int64 + Items []string + Total float64 `reset:"0.0"` // Custom value still works +} + +type Logger struct { + Level string // Will NOT be generated (not in -structs list) +} +``` + +### Package-Qualified Names + +When you have structs with the same name in different packages, use package-qualified names: + +```bash +# Process User in models package only +resetgen -structs models.User ./... + +# Process User in both models and api packages +resetgen -structs models.User,api.User ./... + +# Mix simple and qualified names +resetgen -structs Order,models.User ./... +``` + +**Rules:** +- Simple name (`User`) → processes ALL User structs in all packages +- Qualified name (`models.User`) → processes only User in models package +- Package path uses Go import path format (lowercase with dots/slashes) + +**Example with multiple packages:** +```go +// models/user.go +//go:generate resetgen -structs models.User,api.User + +package models + +type User struct { + ID int64 `reset:""` + Name string `reset:""` + Email string `reset:""` +} + +// api/user.go +//go:generate resetgen -structs models.User,api.User + +package api + +type User struct { + ID string `reset:""` + Status string `reset:"active"` +} +``` + +Both packages can use the same go:generate directive with package-qualified names, and each will generate only its own Reset() method. + +## Directive Syntax + +Use the `+resetgen` comment directive to mark structs for automatic `Reset()` generation without tagging every field: + +```go +//go:generate resetgen + +package main + +// +resetgen +type Request struct { + ID string // defaults to zero value + Method string // defaults to zero value + Headers map[string]string // defaults to zero value + Secret string `reset:"-"` // skipped from reset +} +``` + +Generated `request.gen.go`: + +```go +func (s *Request) Reset() { + s.ID = "" + s.Method = "" + clear(s.Headers) // preserves capacity + // Secret is not reset (reset:"-") +} +``` + +### How Directive Works + +- **Struct Selection**: Structs are processed if they have a `+resetgen` comment OR contain `reset` tags +- **Field Processing**: All exported fields are reset to zero values +- **Custom Values**: Fields with explicit `reset` tags use their specified values +- **Skip Fields**: Use `reset:"-"` to exclude specific fields from reset +- **Unexported Fields**: Private fields (lowercase) are automatically skipped for safety + +### Directive Formats + +All of these are recognized: +- `//+resetgen` +- `// +resetgen` +- `// +resetgen` +- `/* +resetgen */` + ## Features - **Allocation-free** — slices truncate (`s[:0]`), maps clear (`clear(m)`) diff --git a/cmd/resetgen-analyzer/analyzer/analyzer.go b/cmd/resetgen-analyzer/analyzer/analyzer.go index d9cbe3b..97d08c2 100644 --- a/cmd/resetgen-analyzer/analyzer/analyzer.go +++ b/cmd/resetgen-analyzer/analyzer/analyzer.go @@ -18,10 +18,10 @@ var Analyzer = &analysis.Analyzer{ Run: run, } +// run performs the analysis on all function declarations and literals in the pass. func run(pass *analysis.Pass) (any, error) { insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) - // Analyze each function separately nodeFilter := []ast.Node{ (*ast.FuncDecl)(nil), (*ast.FuncLit)(nil), @@ -45,11 +45,11 @@ func run(pass *analysis.Pass) (any, error) { return nil, nil } +// analyzeFunction checks a function body for sync.Pool.Put() calls without preceding Reset() calls. +// Tracks which variables have had Reset() called on them and reports violations. func analyzeFunction(pass *analysis.Pass, body *ast.BlockStmt) { - // Track variables that had Reset() called on them resetCalled := make(map[string]bool) - // Walk statements in order ast.Inspect(body, func(n ast.Node) bool { stmt, ok := n.(*ast.ExprStmt) if !ok { @@ -66,7 +66,6 @@ func analyzeFunction(pass *analysis.Pass, body *ast.BlockStmt) { return true } - // Check for x.Reset() calls - track any variable that had Reset called if sel.Sel.Name == "Reset" && len(call.Args) == 0 { varName := extractVarName(sel.X) if varName != "" { @@ -74,7 +73,6 @@ func analyzeFunction(pass *analysis.Pass, body *ast.BlockStmt) { } } - // Check for sync.Pool.Put(x) calls if sel.Sel.Name == "Put" && isSyncPoolMethod(sel, pass.TypesInfo) { if len(call.Args) == 1 { varName := extractVarName(call.Args[0]) @@ -88,14 +86,11 @@ func analyzeFunction(pass *analysis.Pass, body *ast.BlockStmt) { }) } -// extractVarName gets the variable name from an expression -// Handles: x, s.x, s.field.x func extractVarName(expr ast.Expr) string { switch e := expr.(type) { case *ast.Ident: return e.Name case *ast.SelectorExpr: - // For s.field, we still track by the root identifier return extractVarName(e.X) case *ast.StarExpr: return extractVarName(e.X) @@ -103,7 +98,6 @@ func extractVarName(expr ast.Expr) string { return "" } -// isSyncPoolMethod checks if sel is a method on sync.Pool func isSyncPoolMethod(sel *ast.SelectorExpr, info *types.Info) bool { tv, ok := info.Types[sel.X] if !ok { diff --git a/internal/generator/generator.go b/internal/generator/generator.go index 5bf9cdb..f96ae1e 100644 --- a/internal/generator/generator.go +++ b/internal/generator/generator.go @@ -7,7 +7,8 @@ import ( "github.com/flaticols/resetgen/internal/types" ) -// Generate produces Reset() methods for all structs in the file info. +// Generate produces Reset() methods for all structs in the file info, including package +// declaration, imports, and all Reset() method implementations. func Generate(info *types.FileInfo) string { if len(info.Structs) == 0 { return "" @@ -44,7 +45,7 @@ func Generate(info *types.FileInfo) string { return b.String() } -// GenerateStruct produces a Reset() method for a single struct. +// GenerateStruct generates a single Reset() method for a struct. func GenerateStruct(s *types.StructInfo) string { var b strings.Builder b.Grow(512) @@ -52,6 +53,8 @@ func GenerateStruct(s *types.StructInfo) string { return b.String() } +// collectImports extracts all required standard library imports from struct fields. +// Maps common package aliases to their full import paths. func collectImports(structs []types.StructInfo) []string { pkgSet := make(map[string]bool) @@ -73,6 +76,7 @@ func collectImports(structs []types.StructInfo) []string { return imports } +// extractPackage maps package aliases to their full import paths for standard library types. func extractPackage(typeStr string) string { t := strings.TrimPrefix(typeStr, "*") idx := strings.Index(t, ".") @@ -150,6 +154,8 @@ func generateFieldReset(b *strings.Builder, f *types.FieldInfo) { } } +// generateDefaultReset writes the code to reset a field to its default value. +// For slices/maps and embedded structs, delegates to appropriate zero-reset logic. func generateDefaultReset(b *strings.Builder, f *types.FieldInfo, accessor string) { switch f.Kind { case types.KindSlice, types.KindMap: @@ -181,6 +187,8 @@ func generateDefaultReset(b *strings.Builder, f *types.FieldInfo, accessor strin } } +// generateZeroReset writes the code to reset a field to its zero value. +// Handles embedded types by calling their Reset(), slices by truncating, and maps by clearing. func generateZeroReset(b *strings.Builder, f *types.FieldInfo, accessor string) { if f.IsEmbedded { if isExternalType(f.TypeStr) { @@ -299,6 +307,7 @@ func formatDefault(f *types.FieldInfo) string { } } +// zeroValue returns the zero value literal for a Go type. func zeroValue(typeStr string) string { switch typeStr { case "string": diff --git a/internal/parser/parser.go b/internal/parser/parser.go index 92b34d3..be043c8 100644 --- a/internal/parser/parser.go +++ b/internal/parser/parser.go @@ -11,10 +11,14 @@ import ( "github.com/flaticols/resetgen/internal/types" ) -const tagName = "reset" +const ( + tagName = "reset" + toolDirective = "+resetgen" +) // ParseFile parses a Go source file and extracts structs with reset tags. -func ParseFile(path string) (*types.FileInfo, error) { +// If structFilter is provided, only the listed struct names are processed. +func ParseFile(path string, structFilter map[string]bool) (*types.FileInfo, error) { fset := token.NewFileSet() f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) if err != nil { @@ -43,7 +47,7 @@ func ParseFile(path string) (*types.FileInfo, error) { continue } - structInfo := parseStruct(typeSpec.Name.Name, structType) + structInfo := parseStruct(typeSpec.Name.Name, structType, genDecl, structFilter) if structInfo != nil { structInfo.PkgName = info.PkgName info.Structs = append(info.Structs, *structInfo) @@ -55,7 +59,14 @@ func ParseFile(path string) (*types.FileInfo, error) { } // ParseSource parses Go source code from a string. +// Kept for backward compatibility with existing tests. func ParseSource(src string) (*types.FileInfo, error) { + return ParseSourceWithFilter(src, nil) +} + +// ParseSourceWithFilter parses Go source code with an optional struct filter. +// If structFilter is provided, only the listed struct names are processed. +func ParseSourceWithFilter(src string, structFilter map[string]bool) (*types.FileInfo, error) { fset := token.NewFileSet() f, err := parser.ParseFile(fset, "source.go", src, parser.ParseComments) if err != nil { @@ -84,7 +95,7 @@ func ParseSource(src string) (*types.FileInfo, error) { continue } - structInfo := parseStruct(typeSpec.Name.Name, structType) + structInfo := parseStruct(typeSpec.Name.Name, structType, genDecl, structFilter) if structInfo != nil { structInfo.PkgName = info.PkgName info.Structs = append(info.Structs, *structInfo) @@ -95,39 +106,133 @@ func ParseSource(src string) (*types.FileInfo, error) { return info, nil } -func parseStruct(name string, st *ast.StructType) *types.StructInfo { +// hasResetgenDirective reports whether genDecl has the +resetgen comment directive. +// Recognizes various formats: "//+resetgen", "// +resetgen", "/*+resetgen*/", etc. +func hasResetgenDirective(genDecl *ast.GenDecl) bool { + if genDecl.Doc == nil { + return false + } + + for _, comment := range genDecl.Doc.List { + text := strings.TrimSpace(strings.TrimPrefix(comment.Text, "//")) + text = strings.TrimSpace(strings.TrimPrefix(text, "/*")) + text = strings.TrimSuffix(strings.TrimSpace(text), "*/") + + if strings.HasPrefix(text, toolDirective) { + return true + } + } + + return false +} + +// isExportedType reports whether expr refers to an exported type. +// Pointer-to-type and package-qualified types are considered exported. +func isExportedType(expr ast.Expr) bool { + switch t := expr.(type) { + case *ast.Ident: + return ast.IsExported(t.Name) + case *ast.StarExpr: + return isExportedType(t.X) + case *ast.SelectorExpr: + return true + default: + return false + } +} + +// checkHasResetTag reports whether any field in the struct has a reset tag. +func checkHasResetTag(fields *ast.FieldList) bool { + for _, field := range fields.List { + if field.Tag != nil { + if _, hasTag := parseTag(field.Tag.Value); hasTag { + return true + } + } + } + return false +} + +// parseStruct extracts struct field information based on reset tags and directives. +// When structFilter is provided, all exported fields are included; otherwise only +// fields with reset tags or structs with +resetgen directives are processed. +// Returns nil if the struct should not be processed or has no non-ignored fields. +func parseStruct(name string, st *ast.StructType, genDecl *ast.GenDecl, structFilter map[string]bool) *types.StructInfo { if st.Fields == nil { return nil } + var shouldProcess bool + var processAllExported bool + + if structFilter != nil { + _, shouldProcess = structFilter[name] + processAllExported = shouldProcess + } else { + hasResetTag := checkHasResetTag(st.Fields) + hasDirective := hasResetgenDirective(genDecl) + shouldProcess = hasResetTag || hasDirective + processAllExported = hasDirective + } + + if !shouldProcess { + return nil + } + var fields []types.FieldInfo - hasResetTag := false + hasNonIgnoredFields := false for _, field := range st.Fields.List { - if field.Tag == nil { - continue - } + var tag string + var hasTag bool - tag, ok := parseTag(field.Tag.Value) - if !ok { - continue + if field.Tag != nil { + tag, hasTag = parseTag(field.Tag.Value) } - hasResetTag = true - if len(field.Names) == 0 { - fi := parseField("", field.Type, tag, true) + if !hasTag && !processAllExported { + continue + } + + if processAllExported && !isExportedType(field.Type) { + continue + } + + tagVal := "" + if hasTag { + tagVal = tag + } + fi := parseField("", field.Type, tagVal, true) fields = append(fields, fi) + if fi.Action != types.ActionIgnore { + hasNonIgnoredFields = true + } continue } for _, ident := range field.Names { - fi := parseField(ident.Name, field.Type, tag, false) + if !hasTag && !processAllExported { + continue + } + + if processAllExported && !ast.IsExported(ident.Name) { + continue + } + + tagVal := "" + if hasTag { + tagVal = tag + } + fi := parseField(ident.Name, field.Type, tagVal, false) fields = append(fields, fi) + if fi.Action != types.ActionIgnore { + hasNonIgnoredFields = true + } } } - if !hasResetTag { + if len(fields) == 0 || !hasNonIgnoredFields { return nil } @@ -146,6 +251,8 @@ func parseTag(tagLit string) (string, bool) { return st.Lookup(tagName) } +// parseField creates a FieldInfo from an AST field expression and tag value. +// Determines the field's type kind, name, and reset action based on the tag. func parseField(name string, typeExpr ast.Expr, tagVal string, embedded bool) types.FieldInfo { fi := types.FieldInfo{ Name: name, @@ -224,6 +331,7 @@ func getEmbeddedName(expr ast.Expr) string { } } +// exprToString converts an AST expression to its string representation. func exprToString(expr ast.Expr) string { switch t := expr.(type) { case *ast.Ident: diff --git a/internal/parser/parser_test.go b/internal/parser/parser_test.go index d6be582..e1df52d 100644 --- a/internal/parser/parser_test.go +++ b/internal/parser/parser_test.go @@ -229,6 +229,482 @@ type Event struct { } } +// Directive tests + +func TestParseSource_DirectiveOnly(t *testing.T) { + src := `package test + +// +resetgen +type User struct { + ID int64 + Name string + Email string +} +` + info, err := ParseSource(src) + if err != nil { + t.Fatalf("ParseSource failed: %v", err) + } + + if len(info.Structs) != 1 { + t.Fatalf("expected 1 struct, got %d", len(info.Structs)) + } + + s := info.Structs[0] + if len(s.Fields) != 3 { + t.Fatalf("expected 3 fields, got %d", len(s.Fields)) + } + + // All fields should have ActionZero + for i, f := range s.Fields { + if f.Action != types.ActionZero { + t.Errorf("field %d: expected ActionZero, got %d", i, f.Action) + } + } +} + +func TestParseSource_DirectiveWithTags(t *testing.T) { + src := `package test + +//+resetgen +type User struct { + ID int64 + Name string ` + "`reset:\"guest\"`" + ` + Email string + Age int ` + "`reset:\"-\"`" + ` +} +` + info, err := ParseSource(src) + if err != nil { + t.Fatalf("ParseSource failed: %v", err) + } + + s := info.Structs[0] + if len(s.Fields) != 4 { + t.Fatalf("expected 4 fields, got %d", len(s.Fields)) + } + + tests := []struct { + name string + action types.TagAction + def string + }{ + {"ID", types.ActionZero, ""}, + {"Name", types.ActionDefault, "guest"}, + {"Email", types.ActionZero, ""}, + {"Age", types.ActionIgnore, ""}, + } + + for i, tt := range tests { + f := s.Fields[i] + if f.Name != tt.name { + t.Errorf("field %d: expected name %s, got %s", i, tt.name, f.Name) + } + if f.Action != tt.action { + t.Errorf("field %d: expected action %d, got %d", i, tt.action, f.Action) + } + if f.Default != tt.def { + t.Errorf("field %d: expected default %q, got %q", i, tt.def, f.Default) + } + } +} + +func TestParseSource_DirectiveRespectsIgnore(t *testing.T) { + src := `package test + +// +resetgen +type Config struct { + Host string + Port int + Secret string ` + "`reset:\"-\"`" + ` +} +` + info, err := ParseSource(src) + if err != nil { + t.Fatalf("ParseSource failed: %v", err) + } + + s := info.Structs[0] + if len(s.Fields) != 3 { + t.Fatalf("expected 3 fields, got %d", len(s.Fields)) + } + + // Host and Port should be ActionZero + if s.Fields[0].Action != types.ActionZero { + t.Errorf("Host: expected ActionZero, got %d", s.Fields[0].Action) + } + if s.Fields[1].Action != types.ActionZero { + t.Errorf("Port: expected ActionZero, got %d", s.Fields[1].Action) + } + // Secret should be ActionIgnore + if s.Fields[2].Action != types.ActionIgnore { + t.Errorf("Secret: expected ActionIgnore, got %d", s.Fields[2].Action) + } +} + +func TestParseSource_DirectiveSkipsUnexported(t *testing.T) { + src := `package test + +// +resetgen +type Request struct { + ID string + name string + Token string +} +` + info, err := ParseSource(src) + if err != nil { + t.Fatalf("ParseSource failed: %v", err) + } + + s := info.Structs[0] + if len(s.Fields) != 2 { + t.Fatalf("expected 2 fields (unexported 'name' skipped), got %d", len(s.Fields)) + } + + // Should have ID and Token, but not name + if s.Fields[0].Name != "ID" { + t.Errorf("expected first field ID, got %s", s.Fields[0].Name) + } + if s.Fields[1].Name != "Token" { + t.Errorf("expected second field Token, got %s", s.Fields[1].Name) + } +} + +func TestParseSource_DirectiveAllIgnored(t *testing.T) { + src := `package test + +// +resetgen +type Config struct { + Field1 int ` + "`reset:\"-\"`" + ` + Field2 int ` + "`reset:\"-\"`" + ` +} +` + info, err := ParseSource(src) + if err != nil { + t.Fatalf("ParseSource failed: %v", err) + } + + // Struct should be skipped entirely (no fields to reset) + if len(info.Structs) != 0 { + t.Fatalf("expected 0 structs (all fields ignored), got %d", len(info.Structs)) + } +} + +func TestParseSource_DirectiveFormats(t *testing.T) { + tests := []struct { + name string + src string + expected int + }{ + { + "no space", `package test +// +resetgen +type User struct { + ID int +}`, + 1, + }, + { + "single space", `package test +// +resetgen +type User struct { + ID int +}`, + 1, + }, + { + "multiple spaces", `package test +// +resetgen +type User struct { + ID int +}`, + 1, + }, + { + "no space after slash", `package test +//+resetgen +type User struct { + ID int +}`, + 1, + }, + { + "wrong prefix no plus", `package test +// resetgen +type User struct { + ID int +}`, + 0, + }, + { + "case sensitive", `package test +// +ResetGen +type User struct { + ID int +}`, + 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info, err := ParseSource(tt.src) + if err != nil { + t.Fatalf("ParseSource failed: %v", err) + } + + if len(info.Structs) != tt.expected { + t.Errorf("expected %d structs, got %d", tt.expected, len(info.Structs)) + } + }) + } +} + +func TestParseSource_DirectiveEmbedded(t *testing.T) { + src := `package test + +// +resetgen +type Request struct { + Body io.Reader + Name string +} +` + info, err := ParseSource(src) + if err != nil { + t.Fatalf("ParseSource failed: %v", err) + } + + s := info.Structs[0] + if len(s.Fields) != 2 { + t.Fatalf("expected 2 fields, got %d", len(s.Fields)) + } + + // Both fields should have ActionZero + if s.Fields[0].Action != types.ActionZero { + t.Errorf("Body: expected ActionZero, got %d", s.Fields[0].Action) + } + if s.Fields[1].Action != types.ActionZero { + t.Errorf("Name: expected ActionZero, got %d", s.Fields[1].Action) + } +} + +func TestParseSource_BackwardCompatibility(t *testing.T) { + // Verify that tag-based detection still works without directive + src := `package test + +type NoTags struct { + ID int + Name string +} + +type SomeTags struct { + ID int + Name string ` + "`reset:\"\"`" + ` +} +` + info, err := ParseSource(src) + if err != nil { + t.Fatalf("ParseSource failed: %v", err) + } + + // Should only have SomeTags (backward compatibility) + if len(info.Structs) != 1 { + t.Fatalf("expected 1 struct, got %d", len(info.Structs)) + } + + if info.Structs[0].Name != "SomeTags" { + t.Errorf("expected SomeTags, got %s", info.Structs[0].Name) + } +} + +// Tests for -structs filter functionality + +func TestParseSourceWithFilter_SpecificStructs(t *testing.T) { + src := `package test + +type User struct { + ID int64 + Name string +} + +type Config struct { + Host string + Port int +} + +type Logger struct { + Level string +} +` + // Only process User and Config + filter := map[string]bool{ + "User": true, + "Config": true, + } + + info, err := ParseSourceWithFilter(src, filter) + if err != nil { + t.Fatalf("ParseSourceWithFilter failed: %v", err) + } + + if len(info.Structs) != 2 { + t.Fatalf("expected 2 structs, got %d", len(info.Structs)) + } + + // Should have User and Config, not Logger + names := make(map[string]bool) + for _, s := range info.Structs { + names[s.Name] = true + } + + if !names["User"] { + t.Error("expected User struct") + } + if !names["Config"] { + t.Error("expected Config struct") + } + if names["Logger"] { + t.Error("Logger should not be included") + } +} + +func TestParseSourceWithFilter_AllExportedFields(t *testing.T) { + src := `package test + +type User struct { + ID int64 + Name string + email string + Age int ` + "`reset:\"-\"`" + ` +} +` + filter := map[string]bool{"User": true} + + info, err := ParseSourceWithFilter(src, filter) + if err != nil { + t.Fatalf("ParseSourceWithFilter failed: %v", err) + } + + s := info.Structs[0] + + // Should have ID, Name, and Age (but not email) + if len(s.Fields) != 3 { + t.Fatalf("expected 3 fields, got %d", len(s.Fields)) + } + + // Check field names + hasID := false + hasName := false + hasAge := false + for _, f := range s.Fields { + if f.Name == "ID" { + hasID = true + } + if f.Name == "Name" { + hasName = true + } + if f.Name == "Age" { + hasAge = true + } + } + + if !hasID || !hasName || !hasAge { + t.Errorf("missing expected fields: ID=%v, Name=%v, Age=%v", hasID, hasName, hasAge) + } +} + +func TestParseSourceWithFilter_RespectsTagsInFilteredStructs(t *testing.T) { + src := `package test + +type User struct { + ID int64 ` + "`reset:\"\"`" + ` + Name string ` + "`reset:\"guest\"`" + ` + Email string + Secret string ` + "`reset:\"-\"`" + ` +} +` + filter := map[string]bool{"User": true} + + info, err := ParseSourceWithFilter(src, filter) + if err != nil { + t.Fatalf("ParseSourceWithFilter failed: %v", err) + } + + s := info.Structs[0] + + tests := []struct { + name string + action types.TagAction + def string + }{ + {"ID", types.ActionZero, ""}, + {"Name", types.ActionDefault, "guest"}, + {"Email", types.ActionZero, ""}, + {"Secret", types.ActionIgnore, ""}, + } + + for i, tt := range tests { + f := s.Fields[i] + if f.Name != tt.name { + t.Errorf("field %d: expected name %s, got %s", i, tt.name, f.Name) + } + if f.Action != tt.action { + t.Errorf("field %d (%s): expected action %d, got %d", i, tt.name, tt.action, f.Action) + } + if f.Default != tt.def { + t.Errorf("field %d (%s): expected default %q, got %q", i, tt.name, tt.def, f.Default) + } + } +} + +func TestParseSourceWithFilter_EmptyFilter(t *testing.T) { + src := `package test + +type User struct { + ID int64 + Name string +} +` + filter := map[string]bool{} + + info, err := ParseSourceWithFilter(src, filter) + if err != nil { + t.Fatalf("ParseSourceWithFilter failed: %v", err) + } + + // Empty filter means process nothing + if len(info.Structs) != 0 { + t.Fatalf("expected 0 structs with empty filter, got %d", len(info.Structs)) + } +} + +func TestParseSourceWithFilter_NilFilterUsesDefaultBehavior(t *testing.T) { + src := `package test + +type Tagged struct { + ID int64 ` + "`reset:\"\"`" + ` +} + +type NotTagged struct { + ID int64 +} +` + info, err := ParseSourceWithFilter(src, nil) + if err != nil { + t.Fatalf("ParseSourceWithFilter failed: %v", err) + } + + // Nil filter should use default behavior (only Tagged) + if len(info.Structs) != 1 { + t.Fatalf("expected 1 struct with nil filter, got %d", len(info.Structs)) + } + + if info.Structs[0].Name != "Tagged" { + t.Errorf("expected Tagged struct, got %s", info.Structs[0].Name) + } +} + func BenchmarkParseSource(b *testing.B) { src := `package test diff --git a/main.go b/main.go index ae6ded2..f47a14a 100644 --- a/main.go +++ b/main.go @@ -11,16 +11,19 @@ import ( "github.com/flaticols/resetgen/internal/generator" "github.com/flaticols/resetgen/internal/parser" + "github.com/flaticols/resetgen/internal/types" ) func main() { var ( showVersion bool dryRun bool + structsFlag string ) flag.BoolVar(&showVersion, "version", false, "print version and exit") flag.BoolVar(&dryRun, "dry-run", false, "print generated code instead of writing files") + flag.StringVar(&structsFlag, "structs", "", "comma-separated list of struct names to process (e.g., User,Order,Config)") flag.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: resetgen [flags] [patterns...]\n\n") fmt.Fprintf(os.Stderr, "Generate Reset() methods for structs with reset tags.\n\n") @@ -39,25 +42,70 @@ func main() { return } + var structFilter map[string]bool + if structsFlag != "" { + structFilter = make(map[string]bool) + names := strings.Split(structsFlag, ",") + for _, name := range names { + name = strings.TrimSpace(name) + if name == "" { + continue + } + + if strings.Contains(name, ".") { + parts := strings.Split(name, ".") + if len(parts) != 2 { + fmt.Fprintf(os.Stderr, "resetgen: invalid format %s (use Package.Struct)\n", name) + os.Exit(1) + } + pkgPath := parts[0] + structName := parts[1] + + if !isValidGoIdentifier(structName) { + fmt.Fprintf(os.Stderr, "resetgen: invalid struct name in %s: %s\n", name, structName) + os.Exit(1) + } + + if !isValidPackagePath(pkgPath) { + fmt.Fprintf(os.Stderr, "resetgen: invalid package path in %s: %s\n", name, pkgPath) + os.Exit(1) + } + + structFilter[name] = true + } else { + if !isValidGoIdentifier(name) { + fmt.Fprintf(os.Stderr, "resetgen: invalid struct name: %s\n", name) + os.Exit(1) + } + structFilter[name] = true + } + } + + if len(structFilter) == 0 { + fmt.Fprintln(os.Stderr, "resetgen: -structs flag is empty, nothing to process") + os.Exit(0) + } + } + args := flag.Args() if len(args) == 0 { - // Check for go generate environment if gofile := os.Getenv("GOFILE"); gofile != "" { - // Running via go generate - process current file args = []string{gofile} } else { - // Default: process current directory args = []string{"."} } } - if err := run(args, dryRun); err != nil { + if err := run(args, dryRun, structFilter); err != nil { fmt.Fprintf(os.Stderr, "resetgen: %v\n", err) os.Exit(1) } } -func run(patterns []string, dryRun bool) error { +// run processes Go files found by the given patterns and generates Reset() methods for structs +// that match the structFilter (or have reset tags/directives if no filter is provided). +// If dryRun is true, generated code is printed instead of written to files. +func run(patterns []string, dryRun bool, structFilter map[string]bool) error { files, err := findFiles(patterns) if err != nil { return err @@ -69,7 +117,7 @@ func run(patterns []string, dryRun bool) error { processed := 0 for _, file := range files { - ok, err := processFile(file, dryRun) + ok, err := processFile(file, dryRun, structFilter) if err != nil { return fmt.Errorf("%s: %w", file, err) } @@ -85,12 +133,14 @@ func run(patterns []string, dryRun bool) error { return nil } +// findFiles resolves file patterns (e.g., "./...", "./pkg", "file.go") to a list of Go source files. +// Patterns ending with "/..." recursively walk the directory tree. Hidden directories, vendor, +// and testdata directories are skipped. Test files and generated files are excluded. func findFiles(patterns []string) ([]string, error) { var files []string seen := make(map[string]bool) for _, pattern := range patterns { - // Handle ./... pattern if strings.HasSuffix(pattern, "/...") { dir := strings.TrimSuffix(pattern, "/...") if dir == "." || dir == "" { @@ -101,7 +151,6 @@ func findFiles(patterns []string) ([]string, error) { return err } if info.IsDir() { - // Skip hidden directories and testdata name := info.Name() if strings.HasPrefix(name, ".") || name == "testdata" || name == "vendor" { return filepath.SkipDir @@ -120,14 +169,12 @@ func findFiles(patterns []string) ([]string, error) { continue } - // Check if it's a directory info, err := os.Stat(pattern) if err != nil { return nil, err } if info.IsDir() { - // Process all Go files in directory entries, err := os.ReadDir(pattern) if err != nil { return nil, err @@ -143,7 +190,6 @@ func findFiles(patterns []string) ([]string, error) { } } } else if isGoSourceFile(pattern) && !seen[pattern] { - // Single file files = append(files, pattern) seen[pattern] = true } @@ -156,36 +202,48 @@ func isGoSourceFile(path string) bool { if !strings.HasSuffix(path, ".go") { return false } - // Skip test files and generated files base := filepath.Base(path) - if strings.HasSuffix(base, "_test.go") { - return false - } - if strings.HasSuffix(base, ".gen.go") { + if strings.HasSuffix(base, "_test.go") || strings.HasSuffix(base, ".gen.go") { return false } return true } -func processFile(path string, dryRun bool) (bool, error) { - info, err := parser.ParseFile(path) +// processFile parses a Go file, applies struct filtering, generates Reset() methods, +// and writes the result to a .gen.go file. Returns true if at least one struct was processed. +// Parsed structs are filtered by structFilter if provided; otherwise all structs with +// reset tags or directives are processed. +func processFile(path string, dryRun bool, structFilter map[string]bool) (bool, error) { + info, err := parser.ParseFile(path, nil) if err != nil { return false, err } + if structFilter != nil && len(info.Structs) > 0 { + var filteredStructs []types.StructInfo + for _, s := range info.Structs { + if shouldProcessStruct(s.Name, info.PkgName, structFilter) { + filteredStructs = append(filteredStructs, s) + } + } + info.Structs = filteredStructs + } + if len(info.Structs) == 0 { return false, nil } + if structFilter != nil { + warnUnfoundStructs(info, structFilter) + } + code := generator.Generate(info) if code == "" { return false, nil } - // Format the generated code formatted, err := format.Source([]byte(code)) if err != nil { - // If formatting fails, write unformatted code (useful for debugging) formatted = []byte(code) } @@ -195,7 +253,6 @@ func processFile(path string, dryRun bool) (bool, error) { return true, nil } - // Write to .gen.go file outPath := outputPath(path) if err := os.WriteFile(outPath, formatted, 0o644); err != nil { //nolint:gosec // generated code should be world-readable return false, err @@ -225,3 +282,81 @@ func printVersion() { } fmt.Println("resetgen", "dev") } + +func shouldProcessStruct(structName, pkgName string, filter map[string]bool) bool { + if filter == nil { + return true + } + + if filter[structName] { + return true + } + + qualifiedName := pkgName + "." + structName + return filter[qualifiedName] +} + +// warnUnfoundStructs emits warnings for structs specified in the filter but not found in the file. +// Only warns for entries relevant to this file's package; qualified names are only warned if +// they match this package, while simple names always trigger warnings if not found. +func warnUnfoundStructs(info *types.FileInfo, structFilter map[string]bool) { + if len(structFilter) == 0 { + return + } + + foundNames := make(map[string]bool) + for _, s := range info.Structs { + foundNames[s.Name] = true + foundNames[info.PkgName+"."+s.Name] = true + } + + for name := range structFilter { + if strings.Contains(name, ".") { + parts := strings.Split(name, ".") + if parts[0] == info.PkgName && !foundNames[name] { + fmt.Fprintf(os.Stderr, "resetgen: warning: struct %s not found in %s\n", parts[1], info.Path) + } + } else if !foundNames[name] { + fmt.Fprintf(os.Stderr, "resetgen: warning: struct %s not found in %s\n", name, info.Path) + } + } +} + +func isValidGoIdentifier(name string) bool { + if len(name) == 0 { + return false + } + + if name[0] < 'A' || name[0] > 'Z' { + return false + } + + for i := 1; i < len(name); i++ { + c := name[i] + if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') && + (c < '0' || c > '9') && c != '_' { + return false + } + } + + return true +} + +func isValidPackagePath(path string) bool { + if len(path) == 0 { + return false + } + + if path[0] == '.' { + return false + } + + for _, c := range path { + if (c < 'a' || c > 'z') && (c < '0' || c > '9') && + c != '.' && c != '/' && c != '_' { + return false + } + } + + return true +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..5067e57 --- /dev/null +++ b/main_test.go @@ -0,0 +1,112 @@ +package main + +import "testing" + +func TestIsValidGoIdentifier(t *testing.T) { + tests := []struct { + name string + input string + valid bool + }{ + // Valid identifiers + {"single uppercase letter", "U", true}, + {"simple struct name", "User", true}, + {"with underscore", "User_Data", true}, + {"with number", "User123", true}, + {"CamelCase", "UserConfig", true}, + {"with multiple underscores", "User_Config_Data", true}, + + // Invalid identifiers + {"empty string", "", false}, + {"lowercase start", "user", false}, + {"starts with number", "123User", false}, + {"starts with underscore", "_User", false}, + {"contains space", "User Type", false}, + {"contains hyphen", "User-Type", false}, + {"contains dot", "User.Type", false}, + {"lowercase only", "config", false}, + {"number only", "123", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isValidGoIdentifier(tt.input) + if got != tt.valid { + t.Errorf("isValidGoIdentifier(%q) = %v, want %v", tt.input, got, tt.valid) + } + }) + } +} + +func TestIsValidPackagePath(t *testing.T) { + tests := []struct { + name string + input string + valid bool + }{ + // Valid package paths + {"simple lowercase", "models", true}, + {"with underscore", "api_v2", true}, + {"with dot", "github.com/user/pkg", true}, + {"nested path", "internal/api", true}, + {"complex path", "github.com/flaticols/resetgen", true}, + {"with numbers", "v2", true}, + + // Invalid package paths + {"empty string", "", false}, + {"uppercase only", "MODELS", false}, + {"starts with uppercase", "Models", false}, + {"mixed case not allowed", "myPackage", false}, + {"contains space", "api models", false}, + {"contains hyphen", "api-v2", false}, + {"starts with dot", ".models", false}, + {"contains special chars", "api@models", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isValidPackagePath(tt.input) + if got != tt.valid { + t.Errorf("isValidPackagePath(%q) = %v, want %v", tt.input, got, tt.valid) + } + }) + } +} + +func TestShouldProcessStruct(t *testing.T) { + tests := []struct { + name string + structName string + pkgName string + filter map[string]bool + want bool + }{ + // No filter - process all + {"no filter", "User", "models", nil, true}, + + // Simple name matches + {"simple name match", "User", "models", map[string]bool{"User": true}, true}, + {"simple name no match", "User", "models", map[string]bool{"Order": true}, false}, + + // Package-qualified matches + {"qualified match exact", "User", "models", map[string]bool{"models.User": true}, true}, + {"qualified no match different pkg", "User", "api", map[string]bool{"models.User": true}, false}, + + // Mixed filters + {"simple name takes precedence", "User", "api", map[string]bool{"User": true, "models.User": true}, true}, + {"qualified matches in correct pkg", "User", "api", map[string]bool{"api.User": true}, true}, + + // Empty filter + {"empty filter", "User", "models", map[string]bool{}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := shouldProcessStruct(tt.structName, tt.pkgName, tt.filter) + if got != tt.want { + t.Errorf("shouldProcessStruct(%q, %q, %v) = %v, want %v", + tt.structName, tt.pkgName, tt.filter, got, tt.want) + } + }) + } +} diff --git a/testdata/directive/models.gen.go b/testdata/directive/models.gen.go new file mode 100644 index 0000000..6f09a1e --- /dev/null +++ b/testdata/directive/models.gen.go @@ -0,0 +1,28 @@ +// Code generated by resetgen. DO NOT EDIT. + +package directive + +func (s *Request) Reset() { + s.ID = "" + s.Method = "" + s.Path = "" + clear(s.Headers) + s.Body = s.Body[:0] +} + +func (s *Response) Reset() { + s.Status = 200 + s.Body = s.Body[:0] + clear(s.Headers) +} + +func (s *Config) Reset() { + s.Host = "" + s.Port = 8080 + s.Timeout = 0 +} + +func (s *User) Reset() { + s.ID = 0 + s.Name = "unknown" +} diff --git a/testdata/directive/models.go b/testdata/directive/models.go new file mode 100644 index 0000000..cdee496 --- /dev/null +++ b/testdata/directive/models.go @@ -0,0 +1,41 @@ +//go:generate go run ../.. + +package directive + +// +resetgen +type Request struct { + ID string + Method string + Path string + Headers map[string]string + Body []byte +} + +// +resetgen +type Response struct { + Status int `reset:"200"` + Body []byte + Headers map[string]string +} + +// Struct with directive and mixed tags +// +resetgen +type Config struct { + Host string + Port int `reset:"8080"` + Timeout int + secret string // unexported, will be skipped +} + +// Regular struct with tags (no directive) - should still work +type User struct { + ID int64 `reset:""` + Name string `reset:"unknown"` + Email string +} + +// Struct without directive and no tags - should be ignored +type NoTags struct { + Field1 string + Field2 int +} diff --git a/testdata/qualified-names/api/user.go b/testdata/qualified-names/api/user.go new file mode 100644 index 0000000..b2b63ae --- /dev/null +++ b/testdata/qualified-names/api/user.go @@ -0,0 +1,9 @@ +//go:generate go run ../../.. -structs models.User,api.User + +package api + +type User struct { + ID string `reset:""` + Status string `reset:"active"` + Metadata map[string]string `reset:""` +} diff --git a/testdata/qualified-names/models/user.go b/testdata/qualified-names/models/user.go new file mode 100644 index 0000000..a0c0e3d --- /dev/null +++ b/testdata/qualified-names/models/user.go @@ -0,0 +1,9 @@ +//go:generate go run ../../.. -structs models.User,api.User + +package models + +type User struct { + ID int64 `reset:""` + Name string `reset:""` + Email string `reset:""` +} diff --git a/testdata/structs-flag/models.gen.go b/testdata/structs-flag/models.gen.go new file mode 100644 index 0000000..10ad65f --- /dev/null +++ b/testdata/structs-flag/models.gen.go @@ -0,0 +1,15 @@ +// Code generated by resetgen. DO NOT EDIT. + +package structsflag + +func (s *User) Reset() { + s.ID = 0 + s.Name = "" + s.Email = "" +} + +func (s *Order) Reset() { + s.ID = 0 + s.Total = 0.0 + s.Items = s.Items[:0] +} diff --git a/testdata/structs-flag/models.go b/testdata/structs-flag/models.go new file mode 100644 index 0000000..12d1d4b --- /dev/null +++ b/testdata/structs-flag/models.go @@ -0,0 +1,25 @@ +//go:generate go run ../.. -structs User,Order + +package structsflag + +type User struct { + ID int64 + Name string + Email string + Secret string `reset:"-"` // Should respect the ignore tag even with -structs +} + +type Order struct { + ID int64 + Total float64 `reset:"0.0"` // Should respect custom value + Items []string +} + +type Logger struct { + Level string +} // Should NOT be generated (not in -structs list) + +type Config struct { + Host string `reset:""` + Port int `reset:"8080"` +} // Has tags but NOT in -structs list - should NOT be generated