Skip to content

Commit 1d67064

Browse files
systayclaude
andcommitted
asthelpergen: refactor for improved readability and maintainability
Improve code organization and error handling while preserving all functionality and generated output: - Extract constants (visitableName, anyTypeName) to reduce duplication - Replace log.Fatal() with proper error returns and context - Break down complex createFiles() into focused helper functions - Add comprehensive package and type documentation - Remove unused imports and duplicate definitions All tests pass and generated code remains identical. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> Signed-off-by: Andres Taylor <andres@planetscale.com>
1 parent 902ef84 commit 1d67064

File tree

4 files changed

+173
-53
lines changed

4 files changed

+173
-53
lines changed

go/tools/asthelpergen/asthelpergen.go

Lines changed: 162 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,34 @@ See the License for the specific language governing permissions and
1414
limitations under the License.
1515
*/
1616

17+
// Package asthelpergen provides code generation for AST (Abstract Syntax Tree) helper methods.
18+
//
19+
// This package automatically generates helper methods for AST nodes including:
20+
// - Deep cloning (Clone methods)
21+
// - Equality comparison (Equals methods)
22+
// - Visitor pattern support (Visit methods)
23+
// - AST rewriting/transformation (Rewrite methods)
24+
// - Path enumeration for navigation
25+
// - Copy-on-write functionality
26+
//
27+
// The generator works by discovering all types that implement a root interface and
28+
// then generating the appropriate helper methods for each type using a plugin architecture.
29+
//
30+
// Usage:
31+
//
32+
// result, err := asthelpergen.GenerateASTHelpers(&asthelpergen.Options{
33+
// Packages: []string{"./mypackage"},
34+
// RootInterface: "github.com/example/mypackage.MyASTInterface",
35+
// })
36+
//
37+
// The generated code follows Go conventions and includes proper error handling,
38+
// nil checks, and type safety.
1739
package asthelpergen
1840

1941
import (
2042
"bytes"
2143
"fmt"
2244
"go/types"
23-
"log"
2445
"os"
2546
"path"
2647
"strings"
@@ -32,7 +53,13 @@ import (
3253
"vitess.io/vitess/go/tools/codegen"
3354
)
3455

35-
const licenseFileHeader = `Copyright 2025 The Vitess Authors.
56+
const (
57+
// Common constants used across generators
58+
visitableName = "Visitable"
59+
anyTypeName = "any"
60+
61+
// License header for generated files
62+
licenseFileHeader = `Copyright 2025 The Vitess Authors.
3663
3764
Licensed under the Apache License, Version 2.0 (the "License");
3865
you may not use this file except in compliance with the License.
@@ -45,42 +72,80 @@ distributed under the License is distributed on an "AS IS" BASIS,
4572
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4673
See the License for the specific language governing permissions and
4774
limitations under the License.`
75+
)
4876

4977
type (
78+
// generatorSPI provides services to individual generators during code generation.
79+
// It acts as a service provider interface, giving generators access to type discovery
80+
// and scope information needed for generating helper methods.
5081
generatorSPI interface {
82+
// addType adds a newly discovered type to the processing queue
5183
addType(t types.Type)
84+
// scope returns the type scope for finding implementations
5285
scope() *types.Scope
86+
// findImplementations finds all types that implement the given interface
5387
findImplementations(iff *types.Interface, impl func(types.Type) error) error
54-
iface() *types.Interface // the root interface that all nodes are expected to implement
88+
// iface returns the root interface that all nodes are expected to implement
89+
iface() *types.Interface
5590
}
91+
92+
// generator defines the interface that all specialized generators must implement.
93+
// Each generator handles specific types of Go constructs (structs, interfaces, etc.)
94+
// and produces the appropriate helper methods for those types.
5695
generator interface {
96+
// genFile generates the final output file for this generator
5797
genFile(generatorSPI) (string, *jen.File)
98+
// interfaceMethod handles interface types with type switching logic
5899
interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error
100+
// structMethod handles struct types with field iteration
59101
structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error
102+
// ptrToStructMethod handles pointer-to-struct types
60103
ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error
104+
// ptrToBasicMethod handles pointer-to-basic types (e.g., *int, *string)
61105
ptrToBasicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error
106+
// sliceMethod handles slice types with element processing
62107
sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error
108+
// basicMethod handles basic types (int, string, bool, etc.)
63109
basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error
64110
}
65-
// astHelperGen finds implementations of the given interface,
66-
// and uses the supplied `generator`s to produce the output code
111+
112+
// astHelperGen is the main orchestrator that coordinates the code generation process.
113+
// It discovers implementations of a root interface and uses multiple specialized
114+
// generators to produce helper methods for all discovered types.
67115
astHelperGen struct {
116+
// DebugTypes enables debug output for type processing
68117
DebugTypes bool
69-
mod *packages.Module
70-
sizes types.Sizes
118+
// mod is the Go module information for path resolution
119+
mod *packages.Module
120+
// sizes provides platform-specific type size information
121+
sizes types.Sizes
122+
// namedIface is the root interface type for which helpers are generated
71123
namedIface *types.Named
72-
_iface *types.Interface
73-
gens []generator
124+
// _iface is the underlying interface type
125+
_iface *types.Interface
126+
// gens is the list of specialized generators (clone, equals, visit, etc.)
127+
gens []generator
74128

129+
// _scope is the type scope for finding implementations
75130
_scope *types.Scope
76-
todo []types.Type
131+
// todo is the queue of types that need to be processed
132+
todo []types.Type
77133
}
78134
)
79135

80136
func (gen *astHelperGen) iface() *types.Interface {
81137
return gen._iface
82138
}
83139

140+
// newGenerator creates a new AST helper generator with the specified configuration.
141+
//
142+
// Parameters:
143+
// - mod: Go module information for path resolution
144+
// - sizes: Platform-specific type size information
145+
// - named: The root interface type for which helpers will be generated
146+
// - generators: Specialized generators for different helper types (clone, equals, etc.)
147+
//
148+
// Returns a configured astHelperGen ready to process types and generate code.
84149
func newGenerator(mod *packages.Module, sizes types.Sizes, named *types.Named, generators ...generator) *astHelperGen {
85150
return &astHelperGen{
86151
DebugTypes: true,
@@ -107,7 +172,7 @@ func findImplementations(scope *types.Scope, iff *types.Interface, impl func(typ
107172
case *types.Interface:
108173
// This is OK; interfaces are references
109174
default:
110-
panic(fmt.Errorf("interface %s implemented by %s (%s as %T) without ptr", iff.String(), baseType, tt.String(), tt))
175+
return fmt.Errorf("interface %s implemented by %s (%s as %T) without ptr", iff.String(), baseType, tt.String(), tt)
111176
}
112177
}
113178
if types.TypeString(baseType, noQualifier) == visitableName {
@@ -140,7 +205,10 @@ func (gen *astHelperGen) GenerateCode() (map[string]*jen.File, error) {
140205

141206
gen._scope = pkg.Scope()
142207
gen.todo = append(gen.todo, gen.namedIface)
143-
jenFiles := gen.createFiles()
208+
jenFiles, err := gen.createFiles()
209+
if err != nil {
210+
return nil, err
211+
}
144212

145213
result := map[string]*jen.File{}
146214
for fName, genFile := range jenFiles {
@@ -175,16 +243,34 @@ func VerifyFilesOnDisk(result map[string]*jen.File) (errors []error) {
175243
return errors
176244
}
177245

246+
// Options configures the AST helper generation process.
178247
type Options struct {
179-
Packages []string
248+
// Packages specifies the Go packages to analyze for AST types.
249+
// Can be package paths like "./mypackage" or import paths like "github.com/example/ast".
250+
Packages []string
251+
252+
// RootInterface is the fully qualified name of the root interface that all AST nodes implement.
253+
// Format: "package.path.InterfaceName" (e.g., "github.com/example/ast.Node")
180254
RootInterface string
181255

182-
Clone CloneOptions
256+
// Clone configures the clone generator options
257+
Clone CloneOptions
258+
259+
// Equals configures the equality comparison generator options
183260
Equals EqualsOptions
184261
}
185262

186-
// GenerateASTHelpers loads the input code, constructs the necessary generators,
187-
// and generates the rewriter and clone methods for the AST
263+
// GenerateASTHelpers is the main entry point for generating AST helper methods.
264+
//
265+
// It loads the specified packages, analyzes the types that implement the root interface,
266+
// and generates comprehensive helper methods including clone, equals, visit, rewrite,
267+
// path enumeration, and copy-on-write functionality.
268+
//
269+
// The function returns a map where keys are file paths and values are the generated
270+
// Go source files. The caller is responsible for writing these files to disk.
271+
//
272+
// Returns an error if package loading fails, the root interface cannot be found,
273+
// or code generation encounters any issues.
188274
func GenerateASTHelpers(options *Options) (map[string]*jen.File, error) {
189275
loaded, err := packages.Load(&packages.Config{
190276
Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesSizes | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedModule,
@@ -260,48 +346,76 @@ func (gen *astHelperGen) addType(t types.Type) {
260346
gen.todo = append(gen.todo, t)
261347
}
262348

263-
func (gen *astHelperGen) createFiles() map[string]*jen.File {
349+
func (gen *astHelperGen) createFiles() (map[string]*jen.File, error) {
350+
if err := gen.processTypeQueue(); err != nil {
351+
return nil, err
352+
}
353+
return gen.generateOutputFiles(), nil
354+
}
355+
356+
// processTypeQueue processes all types in the todo queue with all generators
357+
func (gen *astHelperGen) processTypeQueue() error {
264358
alreadyDone := map[string]bool{}
265359
for len(gen.todo) > 0 {
266360
t := gen.todo[0]
267-
underlying := t.Underlying()
268361
typeName := printableTypeName(t)
269362
gen.todo = gen.todo[1:]
270363

271364
if alreadyDone[typeName] {
272365
continue
273366
}
274-
var err error
275-
for _, g := range gen.gens {
276-
switch underlying := underlying.(type) {
277-
case *types.Interface:
278-
err = g.interfaceMethod(t, underlying, gen)
279-
case *types.Slice:
280-
err = g.sliceMethod(t, underlying, gen)
281-
case *types.Struct:
282-
err = g.structMethod(t, underlying, gen)
283-
case *types.Pointer:
284-
ptrToType := underlying.Elem().Underlying()
285-
switch ptrToType := ptrToType.(type) {
286-
case *types.Struct:
287-
err = g.ptrToStructMethod(t, ptrToType, gen)
288-
case *types.Basic:
289-
err = g.ptrToBasicMethod(t, ptrToType, gen)
290-
default:
291-
panic(fmt.Sprintf("%T", ptrToType))
292-
}
293-
case *types.Basic:
294-
err = g.basicMethod(t, underlying, gen)
295-
default:
296-
log.Fatalf("don't know how to handle %s %T", typeName, underlying)
297-
}
298-
if err != nil {
299-
log.Fatal(err)
300-
}
367+
368+
if err := gen.processTypeWithGenerators(t); err != nil {
369+
return fmt.Errorf("failed to process type %s: %w", typeName, err)
301370
}
302371
alreadyDone[typeName] = true
303372
}
373+
return nil
374+
}
375+
376+
// processTypeWithGenerators dispatches a type to all generators based on its underlying type
377+
func (gen *astHelperGen) processTypeWithGenerators(t types.Type) error {
378+
underlying := t.Underlying()
379+
typeName := printableTypeName(t)
380+
381+
for _, g := range gen.gens {
382+
var err error
383+
switch underlying := underlying.(type) {
384+
case *types.Interface:
385+
err = g.interfaceMethod(t, underlying, gen)
386+
case *types.Slice:
387+
err = g.sliceMethod(t, underlying, gen)
388+
case *types.Struct:
389+
err = g.structMethod(t, underlying, gen)
390+
case *types.Pointer:
391+
err = gen.handlePointerType(t, underlying, g)
392+
case *types.Basic:
393+
err = g.basicMethod(t, underlying, gen)
394+
default:
395+
return fmt.Errorf("don't know how to handle type %s %T", typeName, underlying)
396+
}
397+
if err != nil {
398+
return fmt.Errorf("generator failed for type %s: %w", typeName, err)
399+
}
400+
}
401+
return nil
402+
}
304403

404+
// handlePointerType handles pointer types by dispatching to the appropriate method
405+
func (gen *astHelperGen) handlePointerType(t types.Type, ptr *types.Pointer, g generator) error {
406+
ptrToType := ptr.Elem().Underlying()
407+
switch ptrToType := ptrToType.(type) {
408+
case *types.Struct:
409+
return g.ptrToStructMethod(t, ptrToType, gen)
410+
case *types.Basic:
411+
return g.ptrToBasicMethod(t, ptrToType, gen)
412+
default:
413+
return fmt.Errorf("unsupported pointer type %T", ptrToType)
414+
}
415+
}
416+
417+
// generateOutputFiles collects the generated files from all generators
418+
func (gen *astHelperGen) generateOutputFiles() map[string]*jen.File {
305419
result := map[string]*jen.File{}
306420
for _, g := range gen.gens {
307421
fName, jenFile := g.genFile(gen)
@@ -310,6 +424,9 @@ func (gen *astHelperGen) createFiles() map[string]*jen.File {
310424
return result
311425
}
312426

427+
// noQualifier is used to print types without package qualifiers
428+
var noQualifier = func(*types.Package) string { return "" }
429+
313430
// printableTypeName returns a string that can be used as a valid golang identifier
314431
func printableTypeName(t types.Type) string {
315432
switch t := t.(type) {

go/tools/asthelpergen/clone_gen.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ import (
2626
"github.com/dave/jennifer/jen"
2727
)
2828

29+
// CloneOptions configures the clone generator behavior.
2930
type CloneOptions struct {
31+
// Exclude specifies type patterns that should not be deep cloned.
32+
// Types matching these patterns will be returned as-is instead of being cloned.
33+
// Patterns use glob-style matching (e.g., "*NoCloneType").
3034
Exclude []string
3135
}
3236

@@ -68,7 +72,7 @@ func (c *cloneGen) readValueOfType(t types.Type, expr jen.Code, spi generatorSPI
6872
case *types.Basic:
6973
return expr
7074
case *types.Interface:
71-
if types.TypeString(t, noQualifier) == "any" {
75+
if types.TypeString(t, noQualifier) == anyTypeName {
7276
// these fields have to be taken care of manually
7377
return expr
7478
}

go/tools/asthelpergen/equals_gen.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ import (
2626

2727
const Comparator = "Comparator"
2828

29+
// EqualsOptions configures the equals generator behavior.
2930
type EqualsOptions struct {
31+
// AllowCustom specifies types that can have custom equality comparators.
32+
// For these types, the generated Comparator struct will include function fields
33+
// that allow custom comparison logic to be injected at runtime.
3034
AllowCustom []string
3135
}
3236

@@ -167,7 +171,7 @@ func compareAllStructFields(strct *types.Struct, spi generatorSPI) jen.Code {
167171
var others []*jen.Statement
168172
for i := 0; i < strct.NumFields(); i++ {
169173
field := strct.Field(i)
170-
if field.Type().Underlying().String() == "any" || strings.HasPrefix(field.Name(), "_") {
174+
if field.Type().Underlying().String() == anyTypeName || strings.HasPrefix(field.Name(), "_") {
171175
// we can safely ignore this, we do not want ast to contain `any` types.
172176
continue
173177
}

go/tools/asthelpergen/rewrite_gen.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ import (
2424
)
2525

2626
const (
27-
rewriteName = "rewrite"
28-
visitableName = "Visitable"
27+
rewriteName = "rewrite"
2928
)
3029

3130
type rewriteGen struct {
@@ -550,7 +549,3 @@ func returnTrue() jen.Code {
550549
func returnFalse() jen.Code {
551550
return jen.Return(jen.False())
552551
}
553-
554-
var noQualifier = func(p *types.Package) string {
555-
return ""
556-
}

0 commit comments

Comments
 (0)