From bb87d291eba268205b930df9e6028267527451fb Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Fri, 17 Oct 2025 14:40:31 -0500 Subject: [PATCH 01/12] add linter framework --- pkg/lint/README.md | 164 ++++++++ pkg/lint/example/example.go | 132 ++++++ pkg/lint/example/example_test.go | 247 +++++++++++ pkg/lint/lint.go | 166 ++++++++ pkg/lint/lint_test.go | 484 ++++++++++++++++++++++ pkg/lint/linter.go | 36 ++ pkg/lint/linters/invisible_index.go | 91 ++++ pkg/lint/linters/invisible_index_test.go | 188 +++++++++ pkg/lint/linters/multiple_alter.go | 102 +++++ pkg/lint/linters/multiple_alter_test.go | 271 ++++++++++++ pkg/lint/linters/primary_key_type.go | 156 +++++++ pkg/lint/linters/primary_key_type_test.go | 396 ++++++++++++++++++ pkg/lint/registry.go | 131 ++++++ pkg/lint/violation.go | 84 ++++ 14 files changed, 2648 insertions(+) create mode 100644 pkg/lint/README.md create mode 100644 pkg/lint/example/example.go create mode 100644 pkg/lint/example/example_test.go create mode 100644 pkg/lint/lint.go create mode 100644 pkg/lint/lint_test.go create mode 100644 pkg/lint/linter.go create mode 100644 pkg/lint/linters/invisible_index.go create mode 100644 pkg/lint/linters/invisible_index_test.go create mode 100644 pkg/lint/linters/multiple_alter.go create mode 100644 pkg/lint/linters/multiple_alter_test.go create mode 100644 pkg/lint/linters/primary_key_type.go create mode 100644 pkg/lint/linters/primary_key_type_test.go create mode 100644 pkg/lint/registry.go create mode 100644 pkg/lint/violation.go diff --git a/pkg/lint/README.md b/pkg/lint/README.md new file mode 100644 index 00000000..fe00cfff --- /dev/null +++ b/pkg/lint/README.md @@ -0,0 +1,164 @@ +# Linter Framework + +The `lint` package provides a framework for static analysis of MySQL schema definitions and DDL statements. It enables validation and best-practice enforcement beyond the runtime checks provided by the `check` package. + +## Quick Start + +### Creating a Linter + +```go +package linters + +import ( + "github.com/block/spirit/pkg/lint" + "github.com/block/spirit/pkg/statement" +) + +// Register your linter in init() +func init() { + lint.Register(&MyLinter{}) +} + +type MyLinter struct{} + +func (l *MyLinter) Name() string { return "my_linter" } +func (l *MyLinter) Category() string { return "naming" } +func (l *MyLinter) Description() string { return "Checks naming conventions" } +func (l *MyLinter) String() string { return l.Name() } + +func (l *MyLinter) Lint(createTables []*statement.CreateTable, alterStatements []*statement.AbstractStatement) []lint.Violation { + var violations []lint.Violation + + for _, ct := range createTables { + // Check table properties + if /* condition */ { + violations = append(violations, lint.Violation{ + Linter: l, + Severity: lint.SeverityWarning, + Message: "Table name issue", + Location: &lint.Location{ + Table: ct.GetTableName(), + }, + }) + } + } + + return violations +} +``` + +### Running Linters + +```go +import ( + "github.com/block/spirit/pkg/lint" +) + +// Run all enabled linters +violations := lint.RunLinters(tables, stmts, lint.Config{}) + +// Check for errors +if lint.HasErrors(violations) { + // Handle errors +} + +// Filter violations +errors := lint.FilterBySeverity(violations, lint.SeverityError) +warnings := lint.FilterBySeverity(violations, lint.SeverityWarning) +``` + +### Configuring Linters + +```go +// Disable specific linters +violations := lint.RunLinters(tables, stmts, lint.Config{ + Enabled: map[string]bool{ + "table_name_length": false, + "duplicate_column": true, + }, +}) +``` + +## Core Types + +### Severity Levels + +- **ERROR**: Will cause actual problems (syntax errors, MySQL limitations) +- **WARNING**: Best practice violations, potential issues +- **INFO**: Suggestions, style preferences + +### Violation + +```go +type Violation struct { + Linter Linter // The linter that produced this violation + Severity Severity // ERROR, WARNING, or INFO + Message string // Human-readable message + Location *Location // Where the violation occurred + Suggestion *string // Optional fix suggestion + Context map[string]any // Additional context +} +``` + +### Location + +```go +type Location struct { + Table string // Table name + Column *string // Column name (if applicable) + Index *string // Index name (if applicable) + Constraint *string // Constraint name (if applicable) +} +``` + +## API Functions + +### Registration + +- `Register(l Linter)` - Register a linter (call from init()) +- `Enable(name string)` - Enable a linter by name +- `Disable(name string)` - Disable a linter by name +- `List()` - Get all registered linter names +- `ListByCategory(category string)` - Get linters in a category +- `Get(name string)` - Get a linter by name + +### Execution + +- `RunLinters(createTables, alterStatements, config)` - Run all enabled linters +- `HasErrors(violations)` - Check if any violations are errors +- `HasWarnings(violations)` - Check if any violations are warnings +- `FilterBySeverity(violations, severity)` - Filter by severity level +- `FilterByLinter(violations, name)` - Filter by linter name + +## Example Linters + +The `example` package provides two demonstration linters: + +### TableNameLengthLinter + +Checks that table names don't exceed MySQL's 64 character limit. + +```go +lint.Register(example.NewTableNameLengthLinter()) +``` + +### DuplicateColumnLinter + +Detects duplicate column definitions in CREATE TABLE statements. + +```go +lint.Register(&example.DuplicateColumnLinter{}) +``` + +## Contributing + +When adding new linters: + +1. Implement the `Linter` interface +2. Register in `init()` function +3. Add comprehensive tests +4. Document the linter's behavior +5. Choose appropriate severity levels +6. Provide helpful error messages and suggestions + +See `pkg/lint/example/` for reference implementations. diff --git a/pkg/lint/example/example.go b/pkg/lint/example/example.go new file mode 100644 index 00000000..0fc79ced --- /dev/null +++ b/pkg/lint/example/example.go @@ -0,0 +1,132 @@ +// Package example provides example linters to demonstrate the linter framework. +// These linters are for demonstration purposes and are not registered by default. +package example + +import ( + "errors" + "fmt" + + "github.com/block/spirit/pkg/lint" + "github.com/block/spirit/pkg/statement" +) + +// TableNameLengthLinter checks that table names are not too long. +// MySQL has a limit of 64 characters for table names. +type TableNameLengthLinter struct { + maxLength int +} + +// TableNameLengthConfig holds configuration for the table name length linter. +type TableNameLengthConfig struct { + MaxLength int `json:"max_length"` +} + +// NewTableNameLengthLinter creates a new table name length linter with default configuration. +func NewTableNameLengthLinter() *TableNameLengthLinter { + return &TableNameLengthLinter{ + maxLength: 64, // MySQL's limit + } +} + +func (l *TableNameLengthLinter) String() string { + return l.Name() +} + +func (l *TableNameLengthLinter) Name() string { + return "table_name_length" +} + +func (l *TableNameLengthLinter) Category() string { + return "naming" +} + +func (l *TableNameLengthLinter) Description() string { + return "Checks that table names do not exceed the configured maximum length (default: 64 characters)" +} + +func (l *TableNameLengthLinter) DefaultConfig() any { + return TableNameLengthConfig{ + MaxLength: 64, + } +} + +func (l *TableNameLengthLinter) Configure(config any) error { + cfg, ok := config.(TableNameLengthConfig) + if !ok { + return errors.New("invalid config type for table_name_length linter: expected TableNameLengthConfig") + } + + if cfg.MaxLength <= 0 { + return fmt.Errorf("max_length must be positive, got %d", cfg.MaxLength) + } + + l.maxLength = cfg.MaxLength + + return nil +} + +func (l *TableNameLengthLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []lint.Violation { + var violations []lint.Violation + + for _, ct := range createTables { + tableName := ct.GetTableName() + if len(tableName) > l.maxLength { + violations = append(violations, lint.Violation{ + Linter: l, + Severity: lint.SeverityError, + Message: fmt.Sprintf("Table name '%s' exceeds maximum length of %d characters (actual: %d)", tableName, l.maxLength, len(tableName)), + Location: &lint.Location{ + Table: tableName, + }, + }) + } + } + + return violations +} + +// DuplicateColumnLinter checks for duplicate column names in CREATE TABLE statements. +type DuplicateColumnLinter struct{} + +func (l *DuplicateColumnLinter) String() string { + return l.Name() +} + +func (l *DuplicateColumnLinter) Name() string { + return "duplicate_column" +} + +func (l *DuplicateColumnLinter) Category() string { + return "schema" +} + +func (l *DuplicateColumnLinter) Description() string { + return "Detects duplicate column definitions in CREATE TABLE statements" +} + +func (l *DuplicateColumnLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []lint.Violation { + var violations []lint.Violation + + for _, ct := range createTables { + tableName := ct.GetTableName() + seen := make(map[string]bool) + + for _, col := range ct.GetColumns() { + if seen[col.Name] { + violations = append(violations, lint.Violation{ + Linter: l, + Severity: lint.SeverityError, + Message: fmt.Sprintf("Duplicate column definition: '%s'", col.Name), + Location: &lint.Location{ + Table: tableName, + Column: &col.Name, + }, + }) + } + + seen[col.Name] = true + } + } + + return violations +} diff --git a/pkg/lint/example/example_test.go b/pkg/lint/example/example_test.go new file mode 100644 index 00000000..7cfd23c7 --- /dev/null +++ b/pkg/lint/example/example_test.go @@ -0,0 +1,247 @@ +package example + +import ( + "testing" + + "github.com/block/spirit/pkg/lint" + "github.com/block/spirit/pkg/statement" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTableNameLengthLinter_Valid(t *testing.T) { + sql := "CREATE TABLE users (id INT PRIMARY KEY)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := NewTableNameLengthLinter() + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + assert.Empty(t, violations) +} + +func TestTableNameLengthLinter_TooLong(t *testing.T) { + // Create a table name that's 65 characters (exceeds MySQL's 64 char limit) + longName := "this_is_a_very_long_table_name_that_exceeds_the_mysql_limit_abcde" + require.Len(t, longName, 65, "Test setup: name should be 65 chars") + + sql := "CREATE TABLE " + longName + " (id INT PRIMARY KEY)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := NewTableNameLengthLinter() + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + require.Len(t, violations, 1) + assert.Equal(t, "table_name_length", violations[0].Linter.Name()) + assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Contains(t, violations[0].Message, "exceeds maximum length") + assert.Equal(t, longName, violations[0].Location.Table) +} + +func TestTableNameLengthLinter_ExactlyAtLimit(t *testing.T) { + // Create a table name that's exactly 64 characters + exactName := "abcdefghij_abcdefghij_abcdefghij_abcdefghij_abcdefghij_abcdefghi" + require.Len(t, exactName, 64, "Test setup: name should be exactly 64 chars") + + sql := "CREATE TABLE " + exactName + " (id INT PRIMARY KEY)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := NewTableNameLengthLinter() + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + assert.Empty(t, violations, "64 character name should be allowed") +} + +func TestTableNameLengthLinter_Configure(t *testing.T) { + // Create a table name that's 50 characters + name := "abcdefghij_abcdefghij_abcdefghij_abcdefghij_abcdef" + require.Len(t, name, 50, "Test setup: name should be 50 chars") + + sql := "CREATE TABLE " + name + " (id INT PRIMARY KEY)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := NewTableNameLengthLinter() + + // With default config (64), should pass + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + assert.Empty(t, violations) + + // Configure to max length of 40 + err = linter.Configure(TableNameLengthConfig{MaxLength: 40}) + require.NoError(t, err) + + // Now should fail + violations = linter.Lint([]*statement.CreateTable{ct}, nil) + require.Len(t, violations, 1) + assert.Contains(t, violations[0].Message, "exceeds maximum length of 40") +} + +func TestTableNameLengthLinter_Configure_InvalidConfig(t *testing.T) { + linter := NewTableNameLengthLinter() + + // Wrong type + err := linter.Configure("invalid") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid config type") + + // Zero length + err = linter.Configure(TableNameLengthConfig{MaxLength: 0}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must be positive") + + // Negative length + err = linter.Configure(TableNameLengthConfig{MaxLength: -1}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must be positive") +} + +func TestTableNameLengthLinter_DefaultConfig(t *testing.T) { + linter := NewTableNameLengthLinter() + + config := linter.DefaultConfig() + require.NotNil(t, config) + + cfg, ok := config.(TableNameLengthConfig) + require.True(t, ok) + assert.Equal(t, 64, cfg.MaxLength) +} + +func TestDuplicateColumnLinter_NoDuplicates(t *testing.T) { + sql := "CREATE TABLE users (id INT, name VARCHAR(100), email VARCHAR(255))" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &DuplicateColumnLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + assert.Empty(t, violations) +} + +func TestDuplicateColumnLinter_WithDuplicates(t *testing.T) { + sql := "CREATE TABLE users (id INT, name VARCHAR(100), id INT)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &DuplicateColumnLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + require.Len(t, violations, 1) + assert.Equal(t, "duplicate_column", violations[0].Linter.Name()) + assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Contains(t, violations[0].Message, "Duplicate column definition") + assert.Equal(t, "users", violations[0].Location.Table) + assert.NotNil(t, violations[0].Location.Column) + assert.Equal(t, "id", *violations[0].Location.Column) +} + +func TestDuplicateColumnLinter_MultipleDuplicates(t *testing.T) { + sql := "CREATE TABLE test (id INT, name VARCHAR(100), id INT, name VARCHAR(100))" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &DuplicateColumnLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + require.Len(t, violations, 2, "Should detect both duplicate columns") + + // Check that both duplicates are reported + duplicates := make(map[string]bool) + + for _, v := range violations { + if v.Location.Column != nil { + duplicates[*v.Location.Column] = true + } + } + + assert.True(t, duplicates["id"]) + assert.True(t, duplicates["name"]) +} + +func TestExampleLinters_Integration(t *testing.T) { + // Reset the global registry + lint.Reset() + + // Register our example linters + lint.Register(NewTableNameLengthLinter()) + lint.Register(&DuplicateColumnLinter{}) + + // Create a table with both issues + longName := "this_is_a_very_long_table_name_that_exceeds_the_mysql_limit_abcde" + require.Len(t, longName, 65, "Test setup: name should be 65 chars") + + sql := "CREATE TABLE " + longName + " (id INT, name VARCHAR(100), id INT)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + // Run all linters + violations := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{}) + + // Should have violations from both linters + require.Len(t, violations, 2) + + // Check that we have one violation from each linter + linterCounts := make(map[string]int) + for _, v := range violations { + linterCounts[v.Linter.Name()]++ + } + + assert.Equal(t, 1, linterCounts["table_name_length"]) + assert.Equal(t, 1, linterCounts["duplicate_column"]) +} + +func TestExampleLinters_WithConfig(t *testing.T) { + lint.Reset() + + lint.Register(NewTableNameLengthLinter()) + lint.Register(&DuplicateColumnLinter{}) + + longName := "this_is_a_very_long_table_name_that_exceeds_the_mysql_limit_abc" + sql := "CREATE TABLE " + longName + " (id INT, name VARCHAR(100), id INT)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + // Disable the table name length linter + violations := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{ + Enabled: map[string]bool{ + "table_name_length": false, + "duplicate_column": true, + }, + }) + + // Should only have violation from duplicate_column linter + require.Len(t, violations, 1) + assert.Equal(t, "duplicate_column", violations[0].Linter.Name()) +} + +func TestTableNameLengthLinter_WithConfigSettings(t *testing.T) { + lint.Reset() + + lint.Register(NewTableNameLengthLinter()) + + // Create a table name that's 50 characters + name := "abcdefghij_abcdefghij_abcdefghij_abcdefghij_abcdef" + require.Len(t, name, 50, "Test setup: name should be 50 chars") + + sql := "CREATE TABLE " + name + " (id INT PRIMARY KEY)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + // With default config (64), should pass + violations := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{}) + assert.Empty(t, violations) + + // Configure max length to 40 via Config.Settings + violations = lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{ + Settings: map[string]any{ + "table_name_length": TableNameLengthConfig{MaxLength: 40}, + }, + }) + + // Should now have a violation + require.Len(t, violations, 1) + assert.Equal(t, "table_name_length", violations[0].Linter.Name()) + assert.Contains(t, violations[0].Message, "exceeds maximum length of 40") +} diff --git a/pkg/lint/lint.go b/pkg/lint/lint.go new file mode 100644 index 00000000..62de1ed3 --- /dev/null +++ b/pkg/lint/lint.go @@ -0,0 +1,166 @@ +// Package lint provides a framework for static analysis of MySQL schema definitions +// and DDL statements. It enables validation and best-practice enforcement beyond +// the runtime checks provided by the check package. +// +// The linter framework operates on parsed CREATE TABLE statements rather than live +// database connections. +// +// # Basic Usage +// +// Linters are registered via init() functions and executed via RunLinters(): +// +// package naming +// +// func init() { +// lint.Register(&TableNameLinter{}) +// } +// +// // Later, run all linters: +// violations := lint.RunLinters(tables, stmts, config) +// +// # Creating a Linter +// +// To create a custom linter, implement the Linter interface: +// +// type MyLinter struct{} +// +// func (l *MyLinter) Name() string { return "my_linter" } +// func (l *MyLinter) Category() string { return "custom" } +// func (l *MyLinter) Description() string { return "My custom linter" } +// func (l *MyLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []lint.Violation { +// // Perform linting logic +// return violations +// } +// +// # Configuration +// +// Linters can be enabled/disabled via the Config.Enabled map: +// +// config := lint.Config{ +// Enabled: map[string]bool{ +// "table_name": true, +// "column_name": false, +// }, +// } +// +// Configurable linters can implement the ConfigurableLinter interface to accept +// custom settings via Config.Settings. +package lint + +import ( + "github.com/block/spirit/pkg/statement" +) + +// Config holds linter configuration +type Config struct { + // Enabled maps linter names to whether they are enabled + // If a linter is not in this map, it uses its default enabled state + Enabled map[string]bool + + // Settings maps linter names to their configuration + // The configuration type is linter-specific + Settings map[string]any +} + +// RunLinters runs all enabled linters and returns any violations found. +// Linters are executed in an undefined order. +// +// A linter is executed if: +// - It is enabled by default (set during Register), AND +// - It is not explicitly disabled in config.Enabled +// +// OR: +// - It is explicitly enabled in config.Enabled +// +// If a linter implements ConfigurableLinter and has settings in config.Settings, +// those settings are applied before running the linter. +func RunLinters(createTables []*statement.CreateTable, alterStatements []*statement.AbstractStatement, config Config) []Violation { + lock.RLock() + defer lock.RUnlock() + + var violations []Violation + + for name, linter := range linters { + // Check if linter is explicitly disabled in config + if enabled, ok := config.Enabled[name]; ok && !enabled { + continue + } + + // Check if linter is explicitly enabled in config + explicitlyEnabled := false + if enabled, ok := config.Enabled[name]; ok && enabled { + explicitlyEnabled = true + } + + // Skip if not enabled by default and not explicitly enabled + if !linter.enabled && !explicitlyEnabled { + continue + } + + // Apply configuration if available + if configurableLinter, ok := linter.impl.(ConfigurableLinter); ok { + if settings, ok := config.Settings[name]; ok { + err := configurableLinter.Configure(settings) + if err != nil { + // Configuration error - skip this linter + // In a production system, we might want to log this + continue + } + } + } + + // Run the linter + lintViolations := linter.impl.Lint(createTables, alterStatements) + violations = append(violations, lintViolations...) + } + + return violations +} + +// HasErrors returns true if any violations have ERROR severity. +func HasErrors(violations []Violation) bool { + for _, v := range violations { + if v.Severity == SeverityError { + return true + } + } + + return false +} + +// HasWarnings returns true if any violations have WARNING severity. +func HasWarnings(violations []Violation) bool { + for _, v := range violations { + if v.Severity == SeverityWarning { + return true + } + } + + return false +} + +// FilterBySeverity returns only violations with the specified severity. +func FilterBySeverity(violations []Violation, severity Severity) []Violation { + var filtered []Violation + + for _, v := range violations { + if v.Severity == severity { + filtered = append(filtered, v) + } + } + + return filtered +} + +// FilterByLinter returns only violations from the specified linter. +func FilterByLinter(violations []Violation, linterName string) []Violation { + var filtered []Violation + + for _, v := range violations { + if v.Linter.Name() == linterName { + filtered = append(filtered, v) + } + } + + return filtered +} diff --git a/pkg/lint/lint_test.go b/pkg/lint/lint_test.go new file mode 100644 index 00000000..a9e36437 --- /dev/null +++ b/pkg/lint/lint_test.go @@ -0,0 +1,484 @@ +package lint + +import ( + "testing" + + "github.com/block/spirit/pkg/statement" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Mock linter for testing +type mockLinter struct { + name string + category string + description string + violations []Violation +} + +func (m *mockLinter) String() string { + //TODO implement me + panic("implement me") +} + +func (m *mockLinter) Name() string { return m.name } +func (m *mockLinter) Category() string { return m.category } +func (m *mockLinter) Description() string { return m.description } +func (m *mockLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []Violation { + return m.violations +} + +// Configurable mock linter for testing +type mockConfigurableLinter struct { + mockLinter + + configCalled bool + configValue any +} + +func (m *mockConfigurableLinter) String() string { + //TODO implement me + panic("implement me") +} + +func (m *mockConfigurableLinter) Configure(config any) error { + m.configCalled = true + m.configValue = config + + return nil +} + +func (m *mockConfigurableLinter) DefaultConfig() any { + return "default" +} + +func TestRegister(t *testing.T) { + // Reset registry before test + Reset() + + linter := &mockLinter{ + name: "test_linter", + category: "test", + description: "A test linter", + } + + Register(linter) + + // Verify linter was registered + names := List() + assert.Contains(t, names, "test_linter") + + // Verify we can get it back + retrieved, err := Get("test_linter") + require.NoError(t, err) + assert.Equal(t, "test_linter", retrieved.Name()) +} + +func TestRegisterMultiple(t *testing.T) { + Reset() + + linter1 := &mockLinter{name: "linter1", category: "cat1"} + linter2 := &mockLinter{name: "linter2", category: "cat2"} + linter3 := &mockLinter{name: "linter3", category: "cat1"} + + Register(linter1) + Register(linter2) + Register(linter3) + + names := List() + assert.Len(t, names, 3) + assert.Contains(t, names, "linter1") + assert.Contains(t, names, "linter2") + assert.Contains(t, names, "linter3") +} + +func TestListByCategory(t *testing.T) { + Reset() + + linter1 := &mockLinter{name: "linter1", category: "naming"} + linter2 := &mockLinter{name: "linter2", category: "performance"} + linter3 := &mockLinter{name: "linter3", category: "naming"} + + Register(linter1) + Register(linter2) + Register(linter3) + + namingLinters := ListByCategory("naming") + assert.Len(t, namingLinters, 2) + assert.Contains(t, namingLinters, "linter1") + assert.Contains(t, namingLinters, "linter3") + + perfLinters := ListByCategory("performance") + assert.Len(t, perfLinters, 1) + assert.Contains(t, perfLinters, "linter2") + + emptyLinters := ListByCategory("nonexistent") + assert.Empty(t, emptyLinters) +} + +func TestEnableDisable(t *testing.T) { + Reset() + + linter := &mockLinter{name: "test_linter", category: "test"} + Register(linter) + + // Linters are enabled by default + assert.True(t, linters["test_linter"].enabled) + + // Disable it + err := Disable("test_linter") + require.NoError(t, err) + assert.False(t, linters["test_linter"].enabled) + + // Enable it again + err = Enable("test_linter") + require.NoError(t, err) + assert.True(t, linters["test_linter"].enabled) +} + +func TestEnableDisableNonexistent(t *testing.T) { + Reset() + + err := Enable("nonexistent") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + + err = Disable("nonexistent") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestGet(t *testing.T) { + Reset() + + linter := &mockLinter{ + name: "test_linter", + category: "test", + description: "A test linter", + } + Register(linter) + + retrieved, err := Get("test_linter") + require.NoError(t, err) + assert.Equal(t, "test_linter", retrieved.Name()) + assert.Equal(t, "test", retrieved.Category()) + assert.Equal(t, "A test linter", retrieved.Description()) +} + +func TestGetNonexistent(t *testing.T) { + Reset() + + _, err := Get("nonexistent") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestRunLinters_Empty(t *testing.T) { + Reset() + + violations := RunLinters(nil, nil, Config{}) + assert.Empty(t, violations) +} + +func TestRunLinters_SingleLinter(t *testing.T) { + Reset() + + linter := &mockLinter{ + name: "test_linter", + category: "test", + } + + expectedViolations := []Violation{ + { + Linter: linter, + Severity: SeverityError, + Message: "Test error", + }, + } + linter.violations = expectedViolations + + Register(linter) + + violations := RunLinters(nil, nil, Config{}) + assert.Len(t, violations, 1) + assert.Equal(t, "test_linter", violations[0].Linter.Name()) + assert.Equal(t, SeverityError, violations[0].Severity) + assert.Equal(t, "Test error", violations[0].Message) +} + +func TestRunLinters_MultipleLinters(t *testing.T) { + Reset() + + linter1 := &mockLinter{ + name: "linter1", + category: "test", + } + linter1.violations = []Violation{ + {Linter: linter1, Severity: SeverityError, Message: "Error 1"}, + } + + linter2 := &mockLinter{ + name: "linter2", + category: "test", + } + linter2.violations = []Violation{ + {Linter: linter2, Severity: SeverityWarning, Message: "Warning 1"}, + {Linter: linter2, Severity: SeverityInfo, Message: "Info 1"}, + } + + Register(linter1) + Register(linter2) + + violations := RunLinters(nil, nil, Config{}) + assert.Len(t, violations, 3) +} + +func TestRunLinters_WithConfig_Disabled(t *testing.T) { + Reset() + + linter := &mockLinter{ + name: "test_linter", + category: "test", + } + linter.violations = []Violation{ + {Linter: linter, Severity: SeverityError, Message: "Should not see this"}, + } + Register(linter) + + // Disable the linter via config + violations := RunLinters(nil, nil, Config{ + Enabled: map[string]bool{ + "test_linter": false, + }, + }) + + assert.Empty(t, violations) +} + +func TestRunLinters_WithConfig_Enabled(t *testing.T) { + Reset() + + linter := &mockLinter{ + name: "test_linter", + category: "test", + } + linter.violations = []Violation{ + {Linter: linter, Severity: SeverityError, Message: "Should see this"}, + } + + // Disable by default + Register(linter) + require.NoError(t, Disable("test_linter")) + + // But explicitly enable via config + violations := RunLinters(nil, nil, Config{ + Enabled: map[string]bool{ + "test_linter": true, + }, + }) + + assert.Len(t, violations, 1) + assert.Equal(t, "Should see this", violations[0].Message) +} + +func TestRunLinters_ConfigurableLinter(t *testing.T) { + Reset() + + linter := &mockConfigurableLinter{} + linter.name = "configurable_linter" + linter.category = "test" + linter.violations = []Violation{ + {Linter: linter, Severity: SeverityError, Message: "Test"}, + } + Register(linter) + + config := map[string]string{"key": "value"} + violations := RunLinters(nil, nil, Config{ + Settings: map[string]any{ + "configurable_linter": config, + }, + }) + + assert.Len(t, violations, 1) + assert.True(t, linter.configCalled) + assert.Equal(t, config, linter.configValue) +} + +func TestRunLinters_ConfigurableLinter_NoConfig(t *testing.T) { + Reset() + + linter := &mockConfigurableLinter{} + linter.name = "configurable_linter" + linter.category = "test" + linter.violations = []Violation{ + {Linter: linter, Severity: SeverityError, Message: "Test"}, + } + Register(linter) + + violations := RunLinters(nil, nil, Config{}) + + assert.Len(t, violations, 1) + assert.False(t, linter.configCalled) +} + +func TestHasErrors(t *testing.T) { + violations := []Violation{ + {Severity: SeverityWarning}, + {Severity: SeverityInfo}, + } + assert.False(t, HasErrors(violations)) + + violations = append(violations, Violation{Severity: SeverityError}) + assert.True(t, HasErrors(violations)) +} + +func TestHasWarnings(t *testing.T) { + violations := []Violation{ + {Severity: SeverityError}, + {Severity: SeverityInfo}, + } + assert.False(t, HasWarnings(violations)) + + violations = append(violations, Violation{Severity: SeverityWarning}) + assert.True(t, HasWarnings(violations)) +} + +func TestFilterBySeverity(t *testing.T) { + violations := []Violation{ + {Severity: SeverityError, Message: "Error 1"}, + {Severity: SeverityWarning, Message: "Warning 1"}, + {Severity: SeverityError, Message: "Error 2"}, + {Severity: SeverityInfo, Message: "Info 1"}, + } + + errors := FilterBySeverity(violations, SeverityError) + assert.Len(t, errors, 2) + assert.Equal(t, "Error 1", errors[0].Message) + assert.Equal(t, "Error 2", errors[1].Message) + + warnings := FilterBySeverity(violations, SeverityWarning) + assert.Len(t, warnings, 1) + assert.Equal(t, "Warning 1", warnings[0].Message) + + infos := FilterBySeverity(violations, SeverityInfo) + assert.Len(t, infos, 1) + assert.Equal(t, "Info 1", infos[0].Message) +} + +func TestFilterByLinter(t *testing.T) { + linter1 := &mockLinter{name: "linter1", category: "test"} + linter2 := &mockLinter{name: "linter2", category: "test"} + + violations := []Violation{ + {Linter: linter1, Message: "Message 1"}, + {Linter: linter2, Message: "Message 2"}, + {Linter: linter1, Message: "Message 3"}, + } + + linter1Violations := FilterByLinter(violations, "linter1") + assert.Len(t, linter1Violations, 2) + assert.Equal(t, "Message 1", linter1Violations[0].Message) + assert.Equal(t, "Message 3", linter1Violations[1].Message) + + linter2Violations := FilterByLinter(violations, "linter2") + assert.Len(t, linter2Violations, 1) + assert.Equal(t, "Message 2", linter2Violations[0].Message) + + nonexistentViolations := FilterByLinter(violations, "nonexistent") + assert.Empty(t, nonexistentViolations) +} + +func TestListSorted(t *testing.T) { + Reset() + + // Register in non-alphabetical order + Register(&mockLinter{name: "zebra", category: "test"}) + Register(&mockLinter{name: "alpha", category: "test"}) + Register(&mockLinter{name: "beta", category: "test"}) + + names := List() + assert.Equal(t, []string{"alpha", "beta", "zebra"}, names) +} + +func TestListByCategorySorted(t *testing.T) { + Reset() + + // Register in non-alphabetical order + Register(&mockLinter{name: "zebra", category: "cat1"}) + Register(&mockLinter{name: "alpha", category: "cat1"}) + Register(&mockLinter{name: "beta", category: "cat2"}) + Register(&mockLinter{name: "gamma", category: "cat1"}) + + names := ListByCategory("cat1") + assert.Equal(t, []string{"alpha", "gamma", "zebra"}, names) +} + +func TestReset(t *testing.T) { + Reset() + + Register(&mockLinter{name: "linter1", category: "test"}) + Register(&mockLinter{name: "linter2", category: "test"}) + + assert.Len(t, List(), 2) + + Reset() + + assert.Empty(t, List()) +} + +func TestViolationWithLocation(t *testing.T) { + column := "test_column" + index := "test_index" + constraint := "test_constraint" + linter := &mockLinter{name: "test_linter", category: "test"} + + violation := Violation{ + Linter: linter, + Severity: SeverityError, + Message: "Test message", + Location: &Location{ + Table: "test_table", + Column: &column, + Index: &index, + Constraint: &constraint, + }, + } + + assert.Equal(t, "test_table", violation.Location.Table) + assert.Equal(t, "test_column", *violation.Location.Column) + assert.Equal(t, "test_index", *violation.Location.Index) + assert.Equal(t, "test_constraint", *violation.Location.Constraint) +} + +func TestViolationWithSuggestion(t *testing.T) { + suggestion := "Try this instead" + linter := &mockLinter{name: "test_linter", category: "test"} + + violation := Violation{ + Linter: linter, + Severity: SeverityWarning, + Message: "Test message", + Suggestion: &suggestion, + } + + assert.NotNil(t, violation.Suggestion) + assert.Equal(t, "Try this instead", *violation.Suggestion) +} + +func TestViolationWithContext(t *testing.T) { + linter := &mockLinter{name: "test_linter", category: "test"} + + violation := Violation{ + Linter: linter, + Severity: SeverityInfo, + Message: "Test message", + Context: map[string]any{ + "key1": "value1", + "key2": 42, + }, + } + + assert.Len(t, violation.Context, 2) + assert.Equal(t, "value1", violation.Context["key1"]) + assert.Equal(t, 42, violation.Context["key2"]) +} diff --git a/pkg/lint/linter.go b/pkg/lint/linter.go new file mode 100644 index 00000000..c2dd2984 --- /dev/null +++ b/pkg/lint/linter.go @@ -0,0 +1,36 @@ +package lint + +import ( + "github.com/block/spirit/pkg/statement" +) + +// Linter is the interface that all linters must implement +type Linter interface { + // Name returns the unique name of this linter + Name() string + + // Category returns the category this linter belongs to + // (e.g., "naming", "performance", "security", "schema") + Category() string + + // Description returns a human-readable description of what this linter checks + Description() string + + // Lint performs the actual linting and returns any violations found. + // Linters can use either or both of the parameters as needed. + Lint(createTables []*statement.CreateTable, alterStatements []*statement.AbstractStatement) []Violation + + // String returns a string representation of the linter + String() string +} + +// ConfigurableLinter is an optional interface for linters that support configuration +type ConfigurableLinter interface { + Linter + + // Configure applies configuration to the linter + Configure(config any) error + + // DefaultConfig returns the default configuration for this linter + DefaultConfig() any +} diff --git a/pkg/lint/linters/invisible_index.go b/pkg/lint/linters/invisible_index.go new file mode 100644 index 00000000..e25b1827 --- /dev/null +++ b/pkg/lint/linters/invisible_index.go @@ -0,0 +1,91 @@ +package linters + +import ( + "fmt" + + "github.com/block/spirit/pkg/lint" + "github.com/block/spirit/pkg/statement" + "github.com/pingcap/tidb/pkg/parser/ast" +) + +func init() { + lint.Register(&InvisibleIndexBeforeDropLinter{}) +} + +// InvisibleIndexBeforeDropLinter checks that indexes are made invisible before dropping. +// This is a safety practice to ensure the index is not needed before permanently removing it. +type InvisibleIndexBeforeDropLinter struct{} + +func (l *InvisibleIndexBeforeDropLinter) String() string { + return l.Name() +} + +func (l *InvisibleIndexBeforeDropLinter) Name() string { + return "invisible_index_before_drop" +} + +func (l *InvisibleIndexBeforeDropLinter) Category() string { + return "schema" +} + +func (l *InvisibleIndexBeforeDropLinter) Description() string { + return "Requires indexes to be made invisible before dropping them as a safety measure" +} + +func (l *InvisibleIndexBeforeDropLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []lint.Violation { + var violations []lint.Violation + + for _, stmt := range statements { + // Only check ALTER TABLE statements + if !stmt.IsAlterTable() { + continue + } + + alterStmt, ok := (*stmt.StmtNode).(*ast.AlterTableStmt) + if !ok { + continue + } + + tableName := stmt.Table + + // Check each ALTER specification + for _, spec := range alterStmt.Specs { + if spec.Tp != ast.AlterTableDropIndex { + continue + } + + indexName := spec.Name + + madeInvisible := false + // If not made invisible in this ALTER, check if it's invisible in the CREATE TABLE + if len(createTables) > 0 { + for _, ct := range createTables { + if ct.GetTableName() == tableName { + for _, idx := range ct.GetIndexes() { + if idx.Name == indexName && idx.Invisible != nil && *idx.Invisible { + madeInvisible = true + break + } + } + } + } + } + + if !madeInvisible { + suggestion := fmt.Sprintf("First make the index invisible: ALTER TABLE %s ALTER INDEX %s INVISIBLE", tableName, indexName) + violations = append(violations, lint.Violation{ + Linter: l, + Severity: lint.SeverityWarning, + Message: fmt.Sprintf("Index '%s' should be made invisible before dropping to ensure it's not needed", indexName), + Location: &lint.Location{ + Table: tableName, + Index: &indexName, + }, + Suggestion: &suggestion, + }) + } + } + } + + return violations +} diff --git a/pkg/lint/linters/invisible_index_test.go b/pkg/lint/linters/invisible_index_test.go new file mode 100644 index 00000000..13b8046f --- /dev/null +++ b/pkg/lint/linters/invisible_index_test.go @@ -0,0 +1,188 @@ +package linters + +import ( + "testing" + + "github.com/block/spirit/pkg/lint" + "github.com/block/spirit/pkg/statement" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInvisibleIndexBeforeDropLinter_DropWithoutInvisible(t *testing.T) { + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + assert.Equal(t, "invisible_index_before_drop", violations[0].Linter.Name()) + assert.Equal(t, lint.SeverityWarning, violations[0].Severity) + assert.Contains(t, violations[0].Message, "should be made invisible before dropping") + assert.Equal(t, "users", violations[0].Location.Table) + assert.NotNil(t, violations[0].Location.Index) + assert.Equal(t, "idx_email", *violations[0].Location.Index) + assert.NotNil(t, violations[0].Suggestion) + assert.Contains(t, *violations[0].Suggestion, "ALTER INDEX idx_email INVISIBLE") +} + +func TestInvisibleIndexBeforeDropLinter_DropAfterInvisibleInSameAlter(t *testing.T) { + sql := "ALTER TABLE users ALTER INDEX idx_email INVISIBLE, DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint(nil, stmts) + + // Making an index invisible in the same ALTER statement where you drop it is obviously not good enough + assert.Len(t, violations, 1) + assert.IsType(t, &InvisibleIndexBeforeDropLinter{}, violations[0].Linter) + assert.Equal(t, "invisible_index_before_drop", violations[0].Linter.Name()) +} + +func TestInvisibleIndexBeforeDropLinter_DropAlreadyInvisibleIndex(t *testing.T) { + // Create a table with an invisible index + createSQL := `CREATE TABLE users ( + id INT PRIMARY KEY, + email VARCHAR(255), + INDEX idx_email (email) INVISIBLE + )` + ct, err := statement.ParseCreateTable(createSQL) + require.NoError(t, err) + + // Drop the invisible index + alterSQL := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(alterSQL) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, stmts) + + // Should not have violations since index is already invisible + assert.Empty(t, violations) +} + +func TestInvisibleIndexBeforeDropLinter_DropVisibleIndex(t *testing.T) { + // Create a table with a visible index + createSQL := `CREATE TABLE users ( + id INT PRIMARY KEY, + email VARCHAR(255), + INDEX idx_email (email) + )` + ct, err := statement.ParseCreateTable(createSQL) + require.NoError(t, err) + + // Drop the visible index + alterSQL := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(alterSQL) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, stmts) + + // Should have violation since index is visible + require.Len(t, violations, 1) + assert.Equal(t, "invisible_index_before_drop", violations[0].Linter.Name()) +} + +func TestInvisibleIndexBeforeDropLinter_MultipleDrops(t *testing.T) { + sql := "ALTER TABLE users DROP INDEX idx_email, DROP INDEX idx_name" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint(nil, stmts) + + // Should have violations for both indexes + require.Len(t, violations, 2) + + indexNames := make(map[string]bool) + + for _, v := range violations { + assert.Equal(t, "invisible_index_before_drop", v.Linter.Name()) + assert.Equal(t, lint.SeverityWarning, v.Severity) + + if v.Location.Index != nil { + indexNames[*v.Location.Index] = true + } + } + + assert.True(t, indexNames["idx_email"]) + assert.True(t, indexNames["idx_name"]) +} + +func TestInvisibleIndexBeforeDropLinter_NonAlterStatement(t *testing.T) { + sql := "CREATE TABLE users (id INT PRIMARY KEY)" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint(nil, stmts) + + // Should not have violations for non-ALTER statements + assert.Empty(t, violations) +} + +func TestInvisibleIndexBeforeDropLinter_AlterWithoutDrop(t *testing.T) { + sql := "ALTER TABLE users ADD COLUMN age INT" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint(nil, stmts) + + // Should not have violations for ALTER without DROP INDEX + assert.Empty(t, violations) +} + +func TestInvisibleIndexBeforeDropLinter_Integration(t *testing.T) { + // Reset registry and register linter + lint.Reset() + lint.Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + violations := lint.RunLinters(nil, stmts, lint.Config{}) + + require.Len(t, violations, 1) + assert.Equal(t, "invisible_index_before_drop", violations[0].Linter.Name()) +} + +func TestInvisibleIndexBeforeDropLinter_IntegrationDisabled(t *testing.T) { + // Reset registry and register linter + lint.Reset() + lint.Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + // Disable the linter + violations := lint.RunLinters(nil, stmts, lint.Config{ + Enabled: map[string]bool{ + "invisible_index_before_drop": false, + }, + }) + + // Should not have violations when disabled + assert.Empty(t, violations) +} + +func TestInvisibleIndexBeforeDropLinter_Metadata(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + assert.Equal(t, "invisible_index_before_drop", linter.Name()) + assert.Equal(t, "schema", linter.Category()) + assert.NotEmpty(t, linter.Description()) +} diff --git a/pkg/lint/linters/multiple_alter.go b/pkg/lint/linters/multiple_alter.go new file mode 100644 index 00000000..e1a346f1 --- /dev/null +++ b/pkg/lint/linters/multiple_alter.go @@ -0,0 +1,102 @@ +package linters + +import ( + "fmt" + "strings" + + "github.com/block/spirit/pkg/lint" + "github.com/block/spirit/pkg/statement" +) + +func init() { + lint.Register(&MultipleAlterTableLinter{}) +} + +// MultipleAlterTableLinter checks for multiple ALTER TABLE statements affecting the same table. +// Multiple ALTER TABLE statements on the same table should be combined into a single statement +// for better performance, fewer table rebuilds, and decreased danger of bad intermediate state. +type MultipleAlterTableLinter struct{} + +func (l *MultipleAlterTableLinter) String() string { + return l.Name() +} + +func (l *MultipleAlterTableLinter) Name() string { + return "multiple_alter_table" +} + +func (l *MultipleAlterTableLinter) Category() string { + return "schema" +} + +func (l *MultipleAlterTableLinter) Description() string { + return "Detects multiple ALTER TABLE statements on the same table that could be combined" +} + +func (l *MultipleAlterTableLinter) Lint(_ []*statement.CreateTable, statements []*statement.AbstractStatement) []lint.Violation { + var violations []lint.Violation + + // Count ALTER TABLE statements per table + tableAlterCounts := make(map[string][]int) // table name -> statement indices + + for i, stmt := range statements { + if !stmt.IsAlterTable() { + continue + } + + tableName := stmt.Table + if tableName == "" { + continue + } + + tableAlterCounts[tableName] = append(tableAlterCounts[tableName], i) + } + + // Report violations for tables with multiple ALTER statements + for tableName, indices := range tableAlterCounts { + if len(indices) < 2 { + continue + } + + // Build a list of the ALTER operations for the suggestion + var operations []string + + for _, idx := range indices { + if statements[idx].Alter != "" { + operations = append(operations, statements[idx].Alter) + } + } + + suggestion := "" + if len(operations) > 0 { + suggestion = fmt.Sprintf("Combine into: ALTER TABLE %s %s", + tableName, + strings.Join(operations, ", ")) + } + + message := fmt.Sprintf("Table '%s' has %d separate ALTER TABLE statements that could be combined into one for better performance", + tableName, + len(indices)) + + violation := lint.Violation{ + Linter: l, + Severity: lint.SeverityInfo, + Message: message, + Location: &lint.Location{ + Table: tableName, + }, + Context: map[string]any{ + "alter_count": len(indices), + "statement_indices": indices, + }, + } + + if suggestion != "" { + violation.Suggestion = &suggestion + } + + violations = append(violations, violation) + } + + return violations +} diff --git a/pkg/lint/linters/multiple_alter_test.go b/pkg/lint/linters/multiple_alter_test.go new file mode 100644 index 00000000..a6b99522 --- /dev/null +++ b/pkg/lint/linters/multiple_alter_test.go @@ -0,0 +1,271 @@ +package linters + +import ( + "testing" + + "github.com/block/spirit/pkg/lint" + "github.com/block/spirit/pkg/statement" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMultipleAlterTableLinter_SingleAlter(t *testing.T) { + sql := "ALTER TABLE users ADD COLUMN age INT" + stmts, err := statement.New(sql) + require.NoError(t, err) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + // No violation for single ALTER + assert.Empty(t, violations) +} + +func TestMultipleAlterTableLinter_TwoAltersOnSameTable(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD INDEX idx_age (age)` + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 2) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + assert.Equal(t, "multiple_alter_table", violations[0].Linter.Name()) + assert.Equal(t, lint.SeverityInfo, violations[0].Severity) + assert.Contains(t, violations[0].Message, "2 separate ALTER TABLE statements") + assert.Equal(t, "users", violations[0].Location.Table) + assert.NotNil(t, violations[0].Suggestion) + assert.Contains(t, *violations[0].Suggestion, "Combine into") + assert.Contains(t, *violations[0].Suggestion, "ADD COLUMN `age` INT") + assert.Contains(t, *violations[0].Suggestion, "ADD INDEX `idx_age`(`age`)") +} + +func TestMultipleAlterTableLinter_ThreeAltersOnSameTable(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD COLUMN email VARCHAR(255); + ALTER TABLE users ADD INDEX idx_email (email)` + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 3) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + assert.Equal(t, "multiple_alter_table", violations[0].Linter.Name()) + assert.Contains(t, violations[0].Message, "3 separate ALTER TABLE statements") + + // Check context + assert.NotNil(t, violations[0].Context) + assert.Equal(t, 3, violations[0].Context["alter_count"]) + indices, ok := violations[0].Context["statement_indices"].([]int) + require.True(t, ok) + assert.Len(t, indices, 3) +} + +func TestMultipleAlterTableLinter_AltersOnDifferentTables(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE orders ADD COLUMN total DECIMAL(10,2)` + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 2) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + // No violation when altering different tables + assert.Empty(t, violations) +} + +func TestMultipleAlterTableLinter_MixedStatements(t *testing.T) { + // Parse statements individually since statement.New() doesn't support mixed types + var stmts []*statement.AbstractStatement + + sql1 := "CREATE TABLE products (id INT PRIMARY KEY)" + s1, err := statement.New(sql1) + require.NoError(t, err) + + stmts = append(stmts, s1...) + + sql2 := "ALTER TABLE users ADD COLUMN age INT" + s2, err := statement.New(sql2) + require.NoError(t, err) + + stmts = append(stmts, s2...) + + sql3 := "ALTER TABLE users ADD INDEX idx_age (age)" + s3, err := statement.New(sql3) + require.NoError(t, err) + + stmts = append(stmts, s3...) + + sql4 := "DROP TABLE old_table" + s4, err := statement.New(sql4) + require.NoError(t, err) + + stmts = append(stmts, s4...) + + require.Len(t, stmts, 4) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + // Should only detect the two ALTER TABLE users statements + require.Len(t, violations, 1) + assert.Equal(t, "users", violations[0].Location.Table) + assert.Contains(t, violations[0].Message, "2 separate ALTER TABLE statements") +} + +func TestMultipleAlterTableLinter_MultipleTables(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD INDEX idx_age (age); + ALTER TABLE orders ADD COLUMN status VARCHAR(50); + ALTER TABLE orders ADD INDEX idx_status (status)` + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 4) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + // Should detect violations for both tables + require.Len(t, violations, 2) + + tableNames := make(map[string]bool) + for _, v := range violations { + tableNames[v.Location.Table] = true + assert.Equal(t, "multiple_alter_table", v.Linter.Name()) + assert.Contains(t, v.Message, "2 separate ALTER TABLE statements") + } + + assert.True(t, tableNames["users"]) + assert.True(t, tableNames["orders"]) +} + +func TestMultipleAlterTableLinter_ComplexAlters(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT, ADD COLUMN email VARCHAR(255); + ALTER TABLE users DROP COLUMN old_field` + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 2) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + assert.Equal(t, "users", violations[0].Location.Table) + + // Suggestion should include both ALTER operations (with backticks as generated by parser) + assert.NotNil(t, violations[0].Suggestion) + suggestion := *violations[0].Suggestion + assert.Contains(t, suggestion, "ADD COLUMN `age` INT, ADD COLUMN `email` VARCHAR(255)") + assert.Contains(t, suggestion, "DROP COLUMN `old_field`") +} + +func TestMultipleAlterTableLinter_NonAlterStatements(t *testing.T) { + // Parse statements individually since statement.New() doesn't support mixed types + var stmts []*statement.AbstractStatement + + sql1 := "CREATE TABLE users (id INT PRIMARY KEY)" + s1, err := statement.New(sql1) + require.NoError(t, err) + + stmts = append(stmts, s1...) + + sql2 := "CREATE TABLE orders (id INT PRIMARY KEY)" + s2, err := statement.New(sql2) + require.NoError(t, err) + + stmts = append(stmts, s2...) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + // No violations for non-ALTER statements + assert.Empty(t, violations) +} + +func TestMultipleAlterTableLinter_EmptyStatements(t *testing.T) { + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, nil) + + assert.Empty(t, violations) +} + +func TestMultipleAlterTableLinter_Integration(t *testing.T) { + lint.Reset() + lint.Register(&MultipleAlterTableLinter{}) + + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD INDEX idx_age (age)` + stmts, err := statement.New(sql) + require.NoError(t, err) + + violations := lint.RunLinters(nil, stmts, lint.Config{}) + + require.Len(t, violations, 1) + assert.Equal(t, "multiple_alter_table", violations[0].Linter.Name()) +} + +func TestMultipleAlterTableLinter_IntegrationDisabled(t *testing.T) { + lint.Reset() + lint.Register(&MultipleAlterTableLinter{}) + + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD INDEX idx_age (age)` + stmts, err := statement.New(sql) + require.NoError(t, err) + + violations := lint.RunLinters(nil, stmts, lint.Config{ + Enabled: map[string]bool{ + "multiple_alter_table": false, + }, + }) + + assert.Empty(t, violations) +} + +func TestMultipleAlterTableLinter_Metadata(t *testing.T) { + linter := &MultipleAlterTableLinter{} + + assert.Equal(t, "multiple_alter_table", linter.Name()) + assert.Equal(t, "schema", linter.Category()) + assert.NotEmpty(t, linter.Description()) +} + +func TestMultipleAlterTableLinter_ContextData(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD COLUMN email VARCHAR(255); + ALTER TABLE users ADD INDEX idx_email (email)` + stmts, err := statement.New(sql) + require.NoError(t, err) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + + // Verify context contains useful debugging info + assert.NotNil(t, violations[0].Context) + assert.Equal(t, 3, violations[0].Context["alter_count"]) + + indices, ok := violations[0].Context["statement_indices"].([]int) + require.True(t, ok) + assert.Equal(t, []int{0, 1, 2}, indices) +} + +func TestMultipleAlterTableLinter_SeverityIsInfo(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD INDEX idx_age (age)` + stmts, err := statement.New(sql) + require.NoError(t, err) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + // This is INFO level because it's an optimization suggestion, not an error + assert.Equal(t, lint.SeverityInfo, violations[0].Severity) +} diff --git a/pkg/lint/linters/primary_key_type.go b/pkg/lint/linters/primary_key_type.go new file mode 100644 index 00000000..3d65fcfe --- /dev/null +++ b/pkg/lint/linters/primary_key_type.go @@ -0,0 +1,156 @@ +package linters + +import ( + "fmt" + "strings" + + "github.com/block/spirit/pkg/lint" + "github.com/block/spirit/pkg/statement" + "github.com/pingcap/tidb/pkg/parser/mysql" +) + +func init() { + lint.Register(&PrimaryKeyTypeLinter{}) +} + +// PrimaryKeyTypeLinter checks that primary keys use appropriate data types. +// Primary keys should be BIGINT (preferably UNSIGNED) or BINARY/VARBINARY. +// Other types are flagged as errors, and signed BIGINT is flagged as a warning. +type PrimaryKeyTypeLinter struct{} + +func (l *PrimaryKeyTypeLinter) String() string { + return l.Name() +} + +func (l *PrimaryKeyTypeLinter) Name() string { + return "primary_key_type" +} + +func (l *PrimaryKeyTypeLinter) Category() string { + return "schema" +} + +func (l *PrimaryKeyTypeLinter) Description() string { + return "Ensures primary keys use BIGINT (preferably UNSIGNED) or BINARY/VARBINARY types" +} + +func (l *PrimaryKeyTypeLinter) Lint(createTables []*statement.CreateTable, _ []*statement.AbstractStatement) []lint.Violation { + var violations []lint.Violation + + for _, ct := range createTables { + tableName := ct.GetTableName() + + // Get primary key columns from indexes (this includes both table-level and column-level PRIMARY KEY) + pkColumns := l.getPrimaryKeyColumnsFromIndexes(ct) + if len(pkColumns) == 0 { + continue + } + + // Check each primary key column's type + for _, pkCol := range pkColumns { + column := ct.GetColumns().ByName(pkCol) + if column == nil { + continue + } + + violation := l.checkColumnType(tableName, column) + if violation != nil { + violations = append(violations, *violation) + } + } + } + + return violations +} + +// getPrimaryKeyColumnsFromIndexes returns the names of columns that are part of the primary key +// by checking the indexes (which includes both table-level and column-level PRIMARY KEY definitions) +func (l *PrimaryKeyTypeLinter) getPrimaryKeyColumnsFromIndexes(ct *statement.CreateTable) []string { + var pkColumns []string + + // Check for PRIMARY KEY in indexes + for _, index := range ct.GetIndexes() { + if index.Type == "PRIMARY KEY" { + pkColumns = append(pkColumns, index.Columns...) + break // There can only be one PRIMARY KEY + } + } + + return pkColumns +} + +// checkColumnType checks if a primary key column has an appropriate type +func (l *PrimaryKeyTypeLinter) checkColumnType(tableName string, column *statement.Column) *lint.Violation { + columnType := strings.ToUpper(column.Type) + + // Check for BIGINT + if strings.HasPrefix(columnType, "BIGINT") { + // Check if it's unsigned (either in type string or Unsigned field) + isUnsigned := column.Unsigned != nil && *column.Unsigned + if !isUnsigned { + isUnsigned = strings.Contains(columnType, "UNSIGNED") + } + + if isUnsigned { + // BIGINT UNSIGNED is ideal - no violation + return nil + } + + // BIGINT without UNSIGNED is a warning + suggestion := fmt.Sprintf("Consider using BIGINT UNSIGNED for column '%s' to avoid negative values and increase range", column.Name) + + return &lint.Violation{ + Linter: l, + Severity: lint.SeverityWarning, + Message: fmt.Sprintf("Primary key column '%s' uses signed BIGINT; UNSIGNED is preferred", column.Name), + Location: &lint.Location{ + Table: tableName, + Column: &column.Name, + }, + Suggestion: &suggestion, + } + } + + // Check for BINARY/VARBINARY + // Note: The parser returns "char" for BINARY and "varchar" for VARBINARY, + // so we need to check the raw type and binary flag + if l.isBinaryType(column) { + // BINARY/VARBINARY is acceptable - no violation + return nil + } + + // Any other type is an error + suggestion := fmt.Sprintf("Change column '%s' to BIGINT UNSIGNED or BINARY/VARBINARY", column.Name) + + return &lint.Violation{ + Linter: l, + Severity: lint.SeverityError, + Message: fmt.Sprintf("Primary key column '%s' has type '%s'; must be BIGINT or BINARY/VARBINARY", column.Name, column.Type), + Location: &lint.Location{ + Table: tableName, + Column: &column.Name, + }, + Suggestion: &suggestion, + Context: map[string]interface{}{ + "current_type": column.Type, + }, + } +} + +// isBinaryType checks if a column is BINARY or VARBINARY type +// The parser returns "char" for BINARY and "varchar" for VARBINARY, so we need to check the binary flag +func (l *PrimaryKeyTypeLinter) isBinaryType(column *statement.Column) bool { + if column.Raw == nil || column.Raw.Tp == nil { + return false + } + + // BINARY is mysql.TypeString with binary flag + // VARBINARY is mysql.TypeVarchar with binary flag + rawType := column.Raw.Tp.GetType() + + // Check if it's a string type with binary flag (BINARY/VARBINARY) + + fmt.Printf("Debug: type=%s rawType=%d, flags=%d, options=%#v\n", column.Type, rawType, column.Raw.Tp.GetFlag(), column.Options) + + return (rawType == mysql.TypeString || rawType == mysql.TypeVarchar) && mysql.HasBinaryFlag(column.Raw.Tp.GetFlag()) +} diff --git a/pkg/lint/linters/primary_key_type_test.go b/pkg/lint/linters/primary_key_type_test.go new file mode 100644 index 00000000..1c90c948 --- /dev/null +++ b/pkg/lint/linters/primary_key_type_test.go @@ -0,0 +1,396 @@ +package linters + +import ( + "testing" + + "github.com/block/spirit/pkg/lint" + "github.com/block/spirit/pkg/statement" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPrimaryKeyTypeLinter_BigIntUnsigned(t *testing.T) { + sql := `CREATE TABLE users ( + id BIGINT UNSIGNED PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // BIGINT UNSIGNED is ideal - no violations + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_BigIntSigned(t *testing.T) { + sql := `CREATE TABLE users ( + id BIGINT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // BIGINT without UNSIGNED should be a warning + require.Len(t, violations, 1) + assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) + assert.Equal(t, lint.SeverityWarning, violations[0].Severity) + assert.Contains(t, violations[0].Message, "signed BIGINT") + assert.Contains(t, violations[0].Message, "UNSIGNED is preferred") + assert.Equal(t, "users", violations[0].Location.Table) + assert.NotNil(t, violations[0].Location.Column) + assert.Equal(t, "id", *violations[0].Location.Column) + assert.NotNil(t, violations[0].Suggestion) + assert.Contains(t, *violations[0].Suggestion, "BIGINT UNSIGNED") +} + +func TestPrimaryKeyTypeLinter_Binary(t *testing.T) { + sql := `CREATE TABLE users ( + id BINARY(16) PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // BINARY is acceptable - no violations + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_VarBinary(t *testing.T) { + sql := `CREATE TABLE users ( + id VARBINARY(255) PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // VARBINARY is acceptable - no violations + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_IntError(t *testing.T) { + sql := `CREATE TABLE users ( + id INT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // INT is not acceptable - should be an error + require.Len(t, violations, 1) + assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) + assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Contains(t, violations[0].Message, "must be BIGINT or BINARY/VARBINARY") + assert.Equal(t, "users", violations[0].Location.Table) + assert.NotNil(t, violations[0].Location.Column) + assert.Equal(t, "id", *violations[0].Location.Column) + assert.NotNil(t, violations[0].Suggestion) + assert.Contains(t, *violations[0].Suggestion, "BIGINT UNSIGNED") +} + +func TestPrimaryKeyTypeLinter_VarcharError(t *testing.T) { + sql := `CREATE TABLE users ( + id VARCHAR(36) PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // VARCHAR is not acceptable - should be an error + require.Len(t, violations, 1) + assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) + assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Contains(t, violations[0].Message, "must be BIGINT or BINARY/VARBINARY") + assert.NotNil(t, violations[0].Context) + // The parser returns lowercase "varchar" + assert.Equal(t, "varchar", violations[0].Context["current_type"]) +} + +func TestPrimaryKeyTypeLinter_CompositePrimaryKey(t *testing.T) { + sql := `CREATE TABLE user_roles ( + user_id BIGINT UNSIGNED, + role_id INT, + PRIMARY KEY (user_id, role_id) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // user_id is BIGINT UNSIGNED (good), role_id is INT (error) + require.Len(t, violations, 1) + assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) + assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Equal(t, "role_id", *violations[0].Location.Column) +} + +func TestPrimaryKeyTypeLinter_CompositePrimaryKeyAllGood(t *testing.T) { + sql := `CREATE TABLE user_roles ( + user_id BIGINT UNSIGNED, + role_id BIGINT UNSIGNED, + PRIMARY KEY (user_id, role_id) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // Both columns are BIGINT UNSIGNED - no violations + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_CompositePrimaryKeyMixed(t *testing.T) { + sql := `CREATE TABLE user_roles ( + user_id BIGINT, + role_id BIGINT UNSIGNED, + PRIMARY KEY (user_id, role_id) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // user_id is signed BIGINT (warning) + require.Len(t, violations, 1) + assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) + assert.Equal(t, lint.SeverityWarning, violations[0].Severity) + assert.Equal(t, "user_id", *violations[0].Location.Column) +} + +func TestPrimaryKeyTypeLinter_NoPrimaryKey(t *testing.T) { + sql := `CREATE TABLE logs ( + id BIGINT UNSIGNED, + message TEXT + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // No primary key - no violations from this linter + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_MultipleTables(t *testing.T) { + sql1 := `CREATE TABLE users ( + id BIGINT UNSIGNED PRIMARY KEY, + name VARCHAR(255) + )` + ct1, err := statement.ParseCreateTable(sql1) + require.NoError(t, err) + + sql2 := `CREATE TABLE orders ( + id INT PRIMARY KEY, + user_id BIGINT UNSIGNED + )` + ct2, err := statement.ParseCreateTable(sql2) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct1, ct2}, nil) + + // Only orders table should have a violation + require.Len(t, violations, 1) + assert.Equal(t, "orders", violations[0].Location.Table) + assert.Equal(t, lint.SeverityError, violations[0].Severity) +} + +func TestPrimaryKeyTypeLinter_SmallIntError(t *testing.T) { + sql := `CREATE TABLE users ( + id SMALLINT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // SMALLINT is not acceptable - should be an error + require.Len(t, violations, 1) + assert.Equal(t, lint.SeverityError, violations[0].Severity) +} + +func TestPrimaryKeyTypeLinter_MediumIntError(t *testing.T) { + sql := `CREATE TABLE users ( + id MEDIUMINT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // MEDIUMINT is not acceptable - should be an error + require.Len(t, violations, 1) + assert.Equal(t, lint.SeverityError, violations[0].Severity) +} + +func TestPrimaryKeyTypeLinter_CharError(t *testing.T) { + sql := `CREATE TABLE users ( + id CHAR(36) PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // CHAR is not acceptable - should be an error + require.Len(t, violations, 1) + assert.Equal(t, lint.SeverityError, violations[0].Severity) +} + +func TestPrimaryKeyTypeLinter_EmptyInput(t *testing.T) { + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint(nil, nil) + + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_Integration(t *testing.T) { + lint.Reset() + lint.Register(&PrimaryKeyTypeLinter{}) + + sql := `CREATE TABLE users ( + id INT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + violations := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{}) + + require.Len(t, violations, 1) + assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) +} + +func TestPrimaryKeyTypeLinter_IntegrationDisabled(t *testing.T) { + lint.Reset() + lint.Register(&PrimaryKeyTypeLinter{}) + + sql := `CREATE TABLE users ( + id INT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + violations := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{ + Enabled: map[string]bool{ + "primary_key_type": false, + }, + }) + + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_Metadata(t *testing.T) { + linter := &PrimaryKeyTypeLinter{} + + assert.Equal(t, "primary_key_type", linter.Name()) + assert.Equal(t, "schema", linter.Category()) + assert.NotEmpty(t, linter.Description()) +} + +func TestPrimaryKeyTypeLinter_BigIntUnsignedExplicit(t *testing.T) { + sql := `CREATE TABLE users ( + id BIGINT(20) UNSIGNED PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // BIGINT(20) UNSIGNED is ideal - no violations + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_CaseInsensitive(t *testing.T) { + // Test that type checking is case-insensitive + sql := `CREATE TABLE users ( + id bigint unsigned PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // Should recognize lowercase bigint unsigned + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_AutoIncrement(t *testing.T) { + sql := `CREATE TABLE users ( + id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // BIGINT UNSIGNED with AUTO_INCREMENT is ideal - no violations + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_UUIDAsVarchar(t *testing.T) { + // Common anti-pattern: using VARCHAR for UUIDs + sql := `CREATE TABLE users ( + uuid VARCHAR(36) PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // VARCHAR for UUID should be an error + require.Len(t, violations, 1) + assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Contains(t, *violations[0].Suggestion, "BINARY/VARBINARY") +} + +func TestPrimaryKeyTypeLinter_UUIDAsBinary(t *testing.T) { + // Correct way to store UUIDs + sql := `CREATE TABLE users ( + uuid BINARY(16) PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // BINARY(16) for UUID is correct - no violations + assert.Empty(t, violations) +} diff --git a/pkg/lint/registry.go b/pkg/lint/registry.go new file mode 100644 index 00000000..92853873 --- /dev/null +++ b/pkg/lint/registry.go @@ -0,0 +1,131 @@ +package lint + +import ( + "fmt" + "sort" + "sync" +) + +// linter represents a registered linter with metadata +type linter struct { + impl Linter + category string + enabled bool +} + +var ( + linters map[string]linter + lock sync.RWMutex +) + +// Register registers a linter with the global registry. +// This should be called from init() functions in linter implementations. +// Linters are enabled by default when registered. +func Register(l Linter) { + lock.Lock() + defer lock.Unlock() + + if linters == nil { + linters = make(map[string]linter) + } + + linters[l.Name()] = linter{ + impl: l, + category: l.Category(), + enabled: true, + } +} + +// Enable enables specific linters by name. +// Returns an error if the linter is not found. +func Enable(names ...string) error { + lock.Lock() + defer lock.Unlock() + + for _, name := range names { + l, ok := linters[name] + if !ok { + return fmt.Errorf("linter %q not found", name) + } + + l.enabled = true + linters[name] = l + } + + return nil +} + +// Disable disables specific linters by name. +// Returns an error if the linter is not found. +func Disable(names ...string) error { + lock.Lock() + defer lock.Unlock() + + for _, name := range names { + l, ok := linters[name] + if !ok { + return fmt.Errorf("linter %q not found", name) + } + + l.enabled = false + linters[name] = l + } + + return nil +} + +// List returns the names of all registered linters in sorted order. +func List() []string { + lock.RLock() + defer lock.RUnlock() + + names := make([]string, 0, len(linters)) + for name := range linters { + names = append(names, name) + } + + sort.Strings(names) + + return names +} + +// ListByCategory returns the names of all linters in a specific category in sorted order. +func ListByCategory(category string) []string { + lock.RLock() + defer lock.RUnlock() + + var names []string + + for name, linter := range linters { + if linter.category == category { + names = append(names, name) + } + } + + sort.Strings(names) + + return names +} + +// Get returns a linter by name. +// Returns an error if the linter is not found. +func Get(name string) (Linter, error) { + lock.RLock() + defer lock.RUnlock() + + l, ok := linters[name] + if !ok { + return nil, fmt.Errorf("linter %q not found", name) + } + + return l.impl, nil +} + +// Reset clears all registered linters. +// This is primarily useful for testing. +func Reset() { + lock.Lock() + defer lock.Unlock() + + linters = make(map[string]linter) +} diff --git a/pkg/lint/violation.go b/pkg/lint/violation.go new file mode 100644 index 00000000..d9f725fc --- /dev/null +++ b/pkg/lint/violation.go @@ -0,0 +1,84 @@ +package lint + +import "fmt" + +// Severity represents the severity level of a linting violation +type Severity string + +const ( + // SeverityError indicates a violation that will cause actual problems + // (syntax errors, MySQL limitations, etc.) + SeverityError Severity = "ERROR" + + // SeverityWarning indicates a best practice violation or potential issue + SeverityWarning Severity = "WARNING" + + // SeverityInfo indicates a suggestion or style preference + SeverityInfo Severity = "INFO" +) + +// Violation represents a linting violation found during analysis +type Violation struct { + // Linter is the linter that produced this violation + Linter Linter + + // Severity is the severity level of the violation + Severity Severity + + // Message is a human-readable description of the violation + Message string + + // Location provides information about where the violation occurred + Location *Location + + // Suggestion is an optional suggestion for fixing the violation + Suggestion *string + + // Context provides additional context-specific information + Context map[string]any +} + +func (v Violation) String() string { + msg := fmt.Sprintf("[%s] %s: %s", v.Severity, v.Linter.Name(), v.Message) + if v.Location != nil { + msg += fmt.Sprintf(" (%s)", v.Location) + } + + if v.Suggestion != nil { + msg += " Suggestion: " + *v.Suggestion + } + + return msg +} + +// Location provides information about where a violation occurred +type Location struct { + // Table is the name of the table where the violation occurred + Table string + + // Column is the name of the column (if applicable) + Column *string + + // Index is the name of the index (if applicable) + Index *string + + // Constraint is the name of the constraint (if applicable) + Constraint *string +} + +func (l *Location) String() string { + msg := "Table: " + l.Table + if l.Column != nil { + msg += ", Column: " + *l.Column + } + + if l.Index != nil { + msg += ", Index: " + *l.Index + } + + if l.Constraint != nil { + msg += ", Constraint: " + *l.Constraint + } + + return msg +} From 61278aa9f8a1e40ac985b59b7ed3b282f231e299 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Fri, 17 Oct 2025 22:18:34 -0500 Subject: [PATCH 02/12] flatten package/dir structure to a single "lint" package. update README.md --- pkg/lint/README.md | 195 +++++++++++++----- ...sible_index.go => lint_invisible_index.go} | 15 +- ...x_test.go => lint_invisible_index_test.go} | 19 +- ...ltiple_alter.go => lint_multiple_alter.go} | 15 +- ...er_test.go => lint_multiple_alter_test.go} | 19 +- ...y_key_type.go => lint_primary_key_type.go} | 23 +-- ..._test.go => lint_primary_key_type_test.go} | 35 ++-- 7 files changed, 204 insertions(+), 117 deletions(-) rename pkg/lint/{linters/invisible_index.go => lint_invisible_index.go} (88%) rename pkg/lint/{linters/invisible_index_test.go => lint_invisible_index_test.go} (93%) rename pkg/lint/{linters/multiple_alter.go => lint_multiple_alter.go} (88%) rename pkg/lint/{linters/multiple_alter_test.go => lint_multiple_alter_test.go} (95%) rename pkg/lint/{linters/primary_key_type.go => lint_primary_key_type.go} (91%) rename pkg/lint/{linters/primary_key_type_test.go => lint_primary_key_type_test.go} (91%) diff --git a/pkg/lint/README.md b/pkg/lint/README.md index fe00cfff..be5de00e 100644 --- a/pkg/lint/README.md +++ b/pkg/lint/README.md @@ -2,41 +2,74 @@ The `lint` package provides a framework for static analysis of MySQL schema definitions and DDL statements. It enables validation and best-practice enforcement beyond the runtime checks provided by the `check` package. +## Architecture + +All built-in linters are automatically registered and enabled when the `lint` package is imported. The framework uses a flat package structure: + +- **Core framework files**: `lint.go`, `linter.go`, `registry.go`, `violation.go` +- **Linter implementations**: `lint_*.go` (e.g., `lint_invisible_index.go`) + ## Quick Start -### Creating a Linter +### Using the Linter Framework ```go -package linters - import ( "github.com/block/spirit/pkg/lint" +) + +// All built-in linters are automatically registered! +violations := lint.RunLinters(tables, stmts, lint.Config{}) + +// Check for errors +if lint.HasErrors(violations) { + // Handle errors +} + +// Filter violations +errors := lint.FilterBySeverity(violations, lint.SeverityError) +warnings := lint.FilterBySeverity(violations, lint.SeverityWarning) +``` + +### Creating a Custom Linter + +Custom linters can be +1. added directly to the `lint` package (in new files with the `lint_` prefix, for consistency) +2. added to your own package and registered by blank import that relies on the `init()` function +3. added to your own code and registered explicitly using `lint.Register()` + +```go +// lint_my_custom.go +package lint + +import ( "github.com/block/spirit/pkg/statement" ) // Register your linter in init() func init() { - lint.Register(&MyLinter{}) + Register(&MyCustomLinter{}) } -type MyLinter struct{} +// MyCustomLinter checks custom rules +type MyCustomLinter struct{} -func (l *MyLinter) Name() string { return "my_linter" } -func (l *MyLinter) Category() string { return "naming" } -func (l *MyLinter) Description() string { return "Checks naming conventions" } -func (l *MyLinter) String() string { return l.Name() } +func (l *MyCustomLinter) Name() string { return "my_custom" } +func (l *MyCustomLinter) Category() string { return "naming" } +func (l *MyCustomLinter) Description() string { return "Checks naming conventions" } +func (l *MyCustomLinter) String() string { return l.Name() } -func (l *MyLinter) Lint(createTables []*statement.CreateTable, alterStatements []*statement.AbstractStatement) []lint.Violation { - var violations []lint.Violation +func (l *MyCustomLinter) Lint(createTables []*statement.CreateTable, alterStatements []*statement.AbstractStatement) []Violation { + var violations []Violation for _, ct := range createTables { // Check table properties if /* condition */ { - violations = append(violations, lint.Violation{ + violations = append(violations, Violation{ Linter: l, - Severity: lint.SeverityWarning, + Severity: SeverityWarning, Message: "Table name issue", - Location: &lint.Location{ + Location: &Location{ Table: ct.GetTableName(), }, }) @@ -47,26 +80,6 @@ func (l *MyLinter) Lint(createTables []*statement.CreateTable, alterStatements [ } ``` -### Running Linters - -```go -import ( - "github.com/block/spirit/pkg/lint" -) - -// Run all enabled linters -violations := lint.RunLinters(tables, stmts, lint.Config{}) - -// Check for errors -if lint.HasErrors(violations) { - // Handle errors -} - -// Filter violations -errors := lint.FilterBySeverity(violations, lint.SeverityError) -warnings := lint.FilterBySeverity(violations, lint.SeverityWarning) -``` - ### Configuring Linters ```go @@ -83,7 +96,7 @@ violations := lint.RunLinters(tables, stmts, lint.Config{ ### Severity Levels -- **ERROR**: Will cause actual problems (syntax errors, MySQL limitations) +- **ERROR**: Will cause actual problems (data loss, inconsistency, MySQL limitations) - **WARNING**: Best practice violations, potential issues - **INFO**: Suggestions, style preferences @@ -130,35 +143,115 @@ type Location struct { - `FilterBySeverity(violations, severity)` - Filter by severity level - `FilterByLinter(violations, name)` - Filter by linter name -## Example Linters +## Built-in Linters -The `example` package provides two demonstration linters: +The `lint` package includes several linters: -### TableNameLengthLinter +### invisible_index_before_drop -Checks that table names don't exceed MySQL's 64 character limit. +**Category**: schema +**Severity**: Warning + +Requires indexes to be made invisible before dropping them as a safety measure. This ensures the index isn't needed before permanently removing it. ```go -lint.Register(example.NewTableNameLengthLinter()) +// ❌ Violation +ALTER TABLE users DROP INDEX idx_email; + +// ✅ Correct +ALTER TABLE users ALTER INDEX idx_email INVISIBLE; +-- Wait and monitor performance +ALTER TABLE users DROP INDEX idx_email; ``` -### DuplicateColumnLinter +### multiple_alter_table -Detects duplicate column definitions in CREATE TABLE statements. +**Category**: schema +**Severity**: Info + +Detects multiple ALTER TABLE statements on the same table that could be combined into one for better performance and fewer table rebuilds. ```go -lint.Register(&example.DuplicateColumnLinter{}) +// ❌ Violation +ALTER TABLE users ADD COLUMN age INT; +ALTER TABLE users ADD INDEX idx_age (age); + +// ✅ Better +ALTER TABLE users + ADD COLUMN age INT, + ADD INDEX idx_age (age); ``` -## Contributing +### primary_key_type + +**Category**: schema +**Severity**: Error (invalid types), Warning (signed BIGINT) + +Ensures primary keys use BIGINT (preferably UNSIGNED) or BINARY/VARBINARY types. -When adding new linters: +```go +// ❌ Error - invalid type +CREATE TABLE users ( + id INT PRIMARY KEY -- Should be BIGINT +); + +// ⚠️ Warning - should be unsigned +CREATE TABLE users ( + id BIGINT PRIMARY KEY -- Should be BIGINT UNSIGNED +); + +// ✅ Correct +CREATE TABLE users ( + id BIGINT UNSIGNED PRIMARY KEY +); +``` + +## Example Linters -1. Implement the `Linter` interface -2. Register in `init()` function -3. Add comprehensive tests -4. Document the linter's behavior -5. Choose appropriate severity levels -6. Provide helpful error messages and suggestions +The `example` package provides demonstration linters for learning purposes: + +### TableNameLengthLinter + +Checks that table names don't exceed MySQL's 64 character limit. + +### DuplicateColumnLinter + +Detects duplicate column definitions in CREATE TABLE statements. See `pkg/lint/example/` for reference implementations. + +## Contributing + +When adding new linters to the `lint` package: + +1. **Create a new file** with the `lint_` prefix (e.g., `lint_my_rule.go`) +2. **Implement the `Linter` interface** with all required methods +3. **Register in `init()`** function to enable automatic registration +4. **Add comprehensive tests** in a corresponding `lint_my_rule_test.go` file +5. **Document the linter** with clear comments and examples +6. **Choose appropriate severity levels**: + - `SeverityError` for violations that will cause actual problems + - `SeverityWarning` for best practice violations + - `SeverityInfo` for suggestions and style preferences +7. **Provide helpful messages** with actionable suggestions when possible +8. **Update this README** with documentation for the new linter + +### File Naming Convention + +- Linter implementation: `lint_.go` +- Linter tests: `lint__test.go` + +### Example Structure + +``` +pkg/lint/ +├── lint.go # Core API +├── linter.go # Interface definition +├── registry.go # Registration system +├── violation.go # Violation types +├── lint_invisible_index.go # Built-in linter +├── lint_multiple_alter.go # Built-in linter +├── lint_primary_key_type.go # Built-in linter +├── lint_my_new_rule.go # Your new linter +└── lint_my_new_rule_test.go # Your tests +``` diff --git a/pkg/lint/linters/invisible_index.go b/pkg/lint/lint_invisible_index.go similarity index 88% rename from pkg/lint/linters/invisible_index.go rename to pkg/lint/lint_invisible_index.go index e25b1827..efd0e7bc 100644 --- a/pkg/lint/linters/invisible_index.go +++ b/pkg/lint/lint_invisible_index.go @@ -1,15 +1,14 @@ -package linters +package lint import ( "fmt" - "github.com/block/spirit/pkg/lint" "github.com/block/spirit/pkg/statement" "github.com/pingcap/tidb/pkg/parser/ast" ) func init() { - lint.Register(&InvisibleIndexBeforeDropLinter{}) + Register(&InvisibleIndexBeforeDropLinter{}) } // InvisibleIndexBeforeDropLinter checks that indexes are made invisible before dropping. @@ -32,8 +31,8 @@ func (l *InvisibleIndexBeforeDropLinter) Description() string { return "Requires indexes to be made invisible before dropping them as a safety measure" } -func (l *InvisibleIndexBeforeDropLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []lint.Violation { - var violations []lint.Violation +func (l *InvisibleIndexBeforeDropLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []Violation { + var violations []Violation for _, stmt := range statements { // Only check ALTER TABLE statements @@ -73,11 +72,11 @@ func (l *InvisibleIndexBeforeDropLinter) Lint(createTables []*statement.CreateTa if !madeInvisible { suggestion := fmt.Sprintf("First make the index invisible: ALTER TABLE %s ALTER INDEX %s INVISIBLE", tableName, indexName) - violations = append(violations, lint.Violation{ + violations = append(violations, Violation{ Linter: l, - Severity: lint.SeverityWarning, + Severity: SeverityWarning, Message: fmt.Sprintf("Index '%s' should be made invisible before dropping to ensure it's not needed", indexName), - Location: &lint.Location{ + Location: &Location{ Table: tableName, Index: &indexName, }, diff --git a/pkg/lint/linters/invisible_index_test.go b/pkg/lint/lint_invisible_index_test.go similarity index 93% rename from pkg/lint/linters/invisible_index_test.go rename to pkg/lint/lint_invisible_index_test.go index 13b8046f..8c5d6ff7 100644 --- a/pkg/lint/linters/invisible_index_test.go +++ b/pkg/lint/lint_invisible_index_test.go @@ -1,9 +1,8 @@ -package linters +package lint import ( "testing" - "github.com/block/spirit/pkg/lint" "github.com/block/spirit/pkg/statement" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,7 +19,7 @@ func TestInvisibleIndexBeforeDropLinter_DropWithoutInvisible(t *testing.T) { require.Len(t, violations, 1) assert.Equal(t, "invisible_index_before_drop", violations[0].Linter.Name()) - assert.Equal(t, lint.SeverityWarning, violations[0].Severity) + assert.Equal(t, SeverityWarning, violations[0].Severity) assert.Contains(t, violations[0].Message, "should be made invisible before dropping") assert.Equal(t, "users", violations[0].Location.Table) assert.NotNil(t, violations[0].Location.Index) @@ -107,7 +106,7 @@ func TestInvisibleIndexBeforeDropLinter_MultipleDrops(t *testing.T) { for _, v := range violations { assert.Equal(t, "invisible_index_before_drop", v.Linter.Name()) - assert.Equal(t, lint.SeverityWarning, v.Severity) + assert.Equal(t, SeverityWarning, v.Severity) if v.Location.Index != nil { indexNames[*v.Location.Index] = true @@ -146,14 +145,14 @@ func TestInvisibleIndexBeforeDropLinter_AlterWithoutDrop(t *testing.T) { func TestInvisibleIndexBeforeDropLinter_Integration(t *testing.T) { // Reset registry and register linter - lint.Reset() - lint.Register(&InvisibleIndexBeforeDropLinter{}) + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) sql := "ALTER TABLE users DROP INDEX idx_email" stmts, err := statement.New(sql) require.NoError(t, err) - violations := lint.RunLinters(nil, stmts, lint.Config{}) + violations := RunLinters(nil, stmts, Config{}) require.Len(t, violations, 1) assert.Equal(t, "invisible_index_before_drop", violations[0].Linter.Name()) @@ -161,15 +160,15 @@ func TestInvisibleIndexBeforeDropLinter_Integration(t *testing.T) { func TestInvisibleIndexBeforeDropLinter_IntegrationDisabled(t *testing.T) { // Reset registry and register linter - lint.Reset() - lint.Register(&InvisibleIndexBeforeDropLinter{}) + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) sql := "ALTER TABLE users DROP INDEX idx_email" stmts, err := statement.New(sql) require.NoError(t, err) // Disable the linter - violations := lint.RunLinters(nil, stmts, lint.Config{ + violations := RunLinters(nil, stmts, Config{ Enabled: map[string]bool{ "invisible_index_before_drop": false, }, diff --git a/pkg/lint/linters/multiple_alter.go b/pkg/lint/lint_multiple_alter.go similarity index 88% rename from pkg/lint/linters/multiple_alter.go rename to pkg/lint/lint_multiple_alter.go index e1a346f1..96b6a13e 100644 --- a/pkg/lint/linters/multiple_alter.go +++ b/pkg/lint/lint_multiple_alter.go @@ -1,15 +1,14 @@ -package linters +package lint import ( "fmt" "strings" - "github.com/block/spirit/pkg/lint" "github.com/block/spirit/pkg/statement" ) func init() { - lint.Register(&MultipleAlterTableLinter{}) + Register(&MultipleAlterTableLinter{}) } // MultipleAlterTableLinter checks for multiple ALTER TABLE statements affecting the same table. @@ -33,8 +32,8 @@ func (l *MultipleAlterTableLinter) Description() string { return "Detects multiple ALTER TABLE statements on the same table that could be combined" } -func (l *MultipleAlterTableLinter) Lint(_ []*statement.CreateTable, statements []*statement.AbstractStatement) []lint.Violation { - var violations []lint.Violation +func (l *MultipleAlterTableLinter) Lint(_ []*statement.CreateTable, statements []*statement.AbstractStatement) []Violation { + var violations []Violation // Count ALTER TABLE statements per table tableAlterCounts := make(map[string][]int) // table name -> statement indices @@ -78,11 +77,11 @@ func (l *MultipleAlterTableLinter) Lint(_ []*statement.CreateTable, statements [ tableName, len(indices)) - violation := lint.Violation{ + violation := Violation{ Linter: l, - Severity: lint.SeverityInfo, + Severity: SeverityInfo, Message: message, - Location: &lint.Location{ + Location: &Location{ Table: tableName, }, Context: map[string]any{ diff --git a/pkg/lint/linters/multiple_alter_test.go b/pkg/lint/lint_multiple_alter_test.go similarity index 95% rename from pkg/lint/linters/multiple_alter_test.go rename to pkg/lint/lint_multiple_alter_test.go index a6b99522..d0aea9f5 100644 --- a/pkg/lint/linters/multiple_alter_test.go +++ b/pkg/lint/lint_multiple_alter_test.go @@ -1,9 +1,8 @@ -package linters +package lint import ( "testing" - "github.com/block/spirit/pkg/lint" "github.com/block/spirit/pkg/statement" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -33,7 +32,7 @@ func TestMultipleAlterTableLinter_TwoAltersOnSameTable(t *testing.T) { require.Len(t, violations, 1) assert.Equal(t, "multiple_alter_table", violations[0].Linter.Name()) - assert.Equal(t, lint.SeverityInfo, violations[0].Severity) + assert.Equal(t, SeverityInfo, violations[0].Severity) assert.Contains(t, violations[0].Message, "2 separate ALTER TABLE statements") assert.Equal(t, "users", violations[0].Location.Table) assert.NotNil(t, violations[0].Suggestion) @@ -195,30 +194,30 @@ func TestMultipleAlterTableLinter_EmptyStatements(t *testing.T) { } func TestMultipleAlterTableLinter_Integration(t *testing.T) { - lint.Reset() - lint.Register(&MultipleAlterTableLinter{}) + Reset() + Register(&MultipleAlterTableLinter{}) sql := `ALTER TABLE users ADD COLUMN age INT; ALTER TABLE users ADD INDEX idx_age (age)` stmts, err := statement.New(sql) require.NoError(t, err) - violations := lint.RunLinters(nil, stmts, lint.Config{}) + violations := RunLinters(nil, stmts, Config{}) require.Len(t, violations, 1) assert.Equal(t, "multiple_alter_table", violations[0].Linter.Name()) } func TestMultipleAlterTableLinter_IntegrationDisabled(t *testing.T) { - lint.Reset() - lint.Register(&MultipleAlterTableLinter{}) + Reset() + Register(&MultipleAlterTableLinter{}) sql := `ALTER TABLE users ADD COLUMN age INT; ALTER TABLE users ADD INDEX idx_age (age)` stmts, err := statement.New(sql) require.NoError(t, err) - violations := lint.RunLinters(nil, stmts, lint.Config{ + violations := RunLinters(nil, stmts, Config{ Enabled: map[string]bool{ "multiple_alter_table": false, }, @@ -267,5 +266,5 @@ func TestMultipleAlterTableLinter_SeverityIsInfo(t *testing.T) { require.Len(t, violations, 1) // This is INFO level because it's an optimization suggestion, not an error - assert.Equal(t, lint.SeverityInfo, violations[0].Severity) + assert.Equal(t, SeverityInfo, violations[0].Severity) } diff --git a/pkg/lint/linters/primary_key_type.go b/pkg/lint/lint_primary_key_type.go similarity index 91% rename from pkg/lint/linters/primary_key_type.go rename to pkg/lint/lint_primary_key_type.go index 3d65fcfe..d26a8dc6 100644 --- a/pkg/lint/linters/primary_key_type.go +++ b/pkg/lint/lint_primary_key_type.go @@ -1,16 +1,15 @@ -package linters +package lint import ( "fmt" "strings" - "github.com/block/spirit/pkg/lint" "github.com/block/spirit/pkg/statement" "github.com/pingcap/tidb/pkg/parser/mysql" ) func init() { - lint.Register(&PrimaryKeyTypeLinter{}) + Register(&PrimaryKeyTypeLinter{}) } // PrimaryKeyTypeLinter checks that primary keys use appropriate data types. @@ -34,8 +33,8 @@ func (l *PrimaryKeyTypeLinter) Description() string { return "Ensures primary keys use BIGINT (preferably UNSIGNED) or BINARY/VARBINARY types" } -func (l *PrimaryKeyTypeLinter) Lint(createTables []*statement.CreateTable, _ []*statement.AbstractStatement) []lint.Violation { - var violations []lint.Violation +func (l *PrimaryKeyTypeLinter) Lint(createTables []*statement.CreateTable, _ []*statement.AbstractStatement) []Violation { + var violations []Violation for _, ct := range createTables { tableName := ct.GetTableName() @@ -80,7 +79,7 @@ func (l *PrimaryKeyTypeLinter) getPrimaryKeyColumnsFromIndexes(ct *statement.Cre } // checkColumnType checks if a primary key column has an appropriate type -func (l *PrimaryKeyTypeLinter) checkColumnType(tableName string, column *statement.Column) *lint.Violation { +func (l *PrimaryKeyTypeLinter) checkColumnType(tableName string, column *statement.Column) *Violation { columnType := strings.ToUpper(column.Type) // Check for BIGINT @@ -99,11 +98,11 @@ func (l *PrimaryKeyTypeLinter) checkColumnType(tableName string, column *stateme // BIGINT without UNSIGNED is a warning suggestion := fmt.Sprintf("Consider using BIGINT UNSIGNED for column '%s' to avoid negative values and increase range", column.Name) - return &lint.Violation{ + return &Violation{ Linter: l, - Severity: lint.SeverityWarning, + Severity: SeverityWarning, Message: fmt.Sprintf("Primary key column '%s' uses signed BIGINT; UNSIGNED is preferred", column.Name), - Location: &lint.Location{ + Location: &Location{ Table: tableName, Column: &column.Name, }, @@ -122,11 +121,11 @@ func (l *PrimaryKeyTypeLinter) checkColumnType(tableName string, column *stateme // Any other type is an error suggestion := fmt.Sprintf("Change column '%s' to BIGINT UNSIGNED or BINARY/VARBINARY", column.Name) - return &lint.Violation{ + return &Violation{ Linter: l, - Severity: lint.SeverityError, + Severity: SeverityError, Message: fmt.Sprintf("Primary key column '%s' has type '%s'; must be BIGINT or BINARY/VARBINARY", column.Name, column.Type), - Location: &lint.Location{ + Location: &Location{ Table: tableName, Column: &column.Name, }, diff --git a/pkg/lint/linters/primary_key_type_test.go b/pkg/lint/lint_primary_key_type_test.go similarity index 91% rename from pkg/lint/linters/primary_key_type_test.go rename to pkg/lint/lint_primary_key_type_test.go index 1c90c948..498c0494 100644 --- a/pkg/lint/linters/primary_key_type_test.go +++ b/pkg/lint/lint_primary_key_type_test.go @@ -1,9 +1,8 @@ -package linters +package lint import ( "testing" - "github.com/block/spirit/pkg/lint" "github.com/block/spirit/pkg/statement" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -38,7 +37,7 @@ func TestPrimaryKeyTypeLinter_BigIntSigned(t *testing.T) { // BIGINT without UNSIGNED should be a warning require.Len(t, violations, 1) assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) - assert.Equal(t, lint.SeverityWarning, violations[0].Severity) + assert.Equal(t, SeverityWarning, violations[0].Severity) assert.Contains(t, violations[0].Message, "signed BIGINT") assert.Contains(t, violations[0].Message, "UNSIGNED is preferred") assert.Equal(t, "users", violations[0].Location.Table) @@ -92,7 +91,7 @@ func TestPrimaryKeyTypeLinter_IntError(t *testing.T) { // INT is not acceptable - should be an error require.Len(t, violations, 1) assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) - assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Equal(t, SeverityError, violations[0].Severity) assert.Contains(t, violations[0].Message, "must be BIGINT or BINARY/VARBINARY") assert.Equal(t, "users", violations[0].Location.Table) assert.NotNil(t, violations[0].Location.Column) @@ -115,7 +114,7 @@ func TestPrimaryKeyTypeLinter_VarcharError(t *testing.T) { // VARCHAR is not acceptable - should be an error require.Len(t, violations, 1) assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) - assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Equal(t, SeverityError, violations[0].Severity) assert.Contains(t, violations[0].Message, "must be BIGINT or BINARY/VARBINARY") assert.NotNil(t, violations[0].Context) // The parser returns lowercase "varchar" @@ -137,7 +136,7 @@ func TestPrimaryKeyTypeLinter_CompositePrimaryKey(t *testing.T) { // user_id is BIGINT UNSIGNED (good), role_id is INT (error) require.Len(t, violations, 1) assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) - assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Equal(t, SeverityError, violations[0].Severity) assert.Equal(t, "role_id", *violations[0].Location.Column) } @@ -172,7 +171,7 @@ func TestPrimaryKeyTypeLinter_CompositePrimaryKeyMixed(t *testing.T) { // user_id is signed BIGINT (warning) require.Len(t, violations, 1) assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) - assert.Equal(t, lint.SeverityWarning, violations[0].Severity) + assert.Equal(t, SeverityWarning, violations[0].Severity) assert.Equal(t, "user_id", *violations[0].Location.Column) } @@ -212,7 +211,7 @@ func TestPrimaryKeyTypeLinter_MultipleTables(t *testing.T) { // Only orders table should have a violation require.Len(t, violations, 1) assert.Equal(t, "orders", violations[0].Location.Table) - assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Equal(t, SeverityError, violations[0].Severity) } func TestPrimaryKeyTypeLinter_SmallIntError(t *testing.T) { @@ -228,7 +227,7 @@ func TestPrimaryKeyTypeLinter_SmallIntError(t *testing.T) { // SMALLINT is not acceptable - should be an error require.Len(t, violations, 1) - assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Equal(t, SeverityError, violations[0].Severity) } func TestPrimaryKeyTypeLinter_MediumIntError(t *testing.T) { @@ -244,7 +243,7 @@ func TestPrimaryKeyTypeLinter_MediumIntError(t *testing.T) { // MEDIUMINT is not acceptable - should be an error require.Len(t, violations, 1) - assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Equal(t, SeverityError, violations[0].Severity) } func TestPrimaryKeyTypeLinter_CharError(t *testing.T) { @@ -260,7 +259,7 @@ func TestPrimaryKeyTypeLinter_CharError(t *testing.T) { // CHAR is not acceptable - should be an error require.Len(t, violations, 1) - assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Equal(t, SeverityError, violations[0].Severity) } func TestPrimaryKeyTypeLinter_EmptyInput(t *testing.T) { @@ -271,8 +270,8 @@ func TestPrimaryKeyTypeLinter_EmptyInput(t *testing.T) { } func TestPrimaryKeyTypeLinter_Integration(t *testing.T) { - lint.Reset() - lint.Register(&PrimaryKeyTypeLinter{}) + Reset() + Register(&PrimaryKeyTypeLinter{}) sql := `CREATE TABLE users ( id INT PRIMARY KEY, @@ -281,15 +280,15 @@ func TestPrimaryKeyTypeLinter_Integration(t *testing.T) { ct, err := statement.ParseCreateTable(sql) require.NoError(t, err) - violations := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{}) + violations := RunLinters([]*statement.CreateTable{ct}, nil, Config{}) require.Len(t, violations, 1) assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) } func TestPrimaryKeyTypeLinter_IntegrationDisabled(t *testing.T) { - lint.Reset() - lint.Register(&PrimaryKeyTypeLinter{}) + Reset() + Register(&PrimaryKeyTypeLinter{}) sql := `CREATE TABLE users ( id INT PRIMARY KEY, @@ -298,7 +297,7 @@ func TestPrimaryKeyTypeLinter_IntegrationDisabled(t *testing.T) { ct, err := statement.ParseCreateTable(sql) require.NoError(t, err) - violations := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{ + violations := RunLinters([]*statement.CreateTable{ct}, nil, Config{ Enabled: map[string]bool{ "primary_key_type": false, }, @@ -375,7 +374,7 @@ func TestPrimaryKeyTypeLinter_UUIDAsVarchar(t *testing.T) { // VARCHAR for UUID should be an error require.Len(t, violations, 1) - assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Equal(t, SeverityError, violations[0].Severity) assert.Contains(t, *violations[0].Suggestion, "BINARY/VARBINARY") } From 8c8bda6e188b4f04c7d64f395c3827cea9370966 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Mon, 20 Oct 2025 12:32:31 -0600 Subject: [PATCH 03/12] Removed Category from Linter interface --- pkg/lint/example/example.go | 10 +-- pkg/lint/example/example_test.go | 8 +-- pkg/lint/lint_invisible_index.go | 4 -- pkg/lint/lint_invisible_index_test.go | 1 - pkg/lint/lint_multiple_alter.go | 4 -- pkg/lint/lint_multiple_alter_test.go | 1 - pkg/lint/lint_primary_key_type.go | 4 -- pkg/lint/lint_primary_key_type_test.go | 1 - pkg/lint/lint_test.go | 87 ++++++-------------------- pkg/lint/linter.go | 4 -- pkg/lint/registry.go | 28 ++------- 11 files changed, 28 insertions(+), 124 deletions(-) diff --git a/pkg/lint/example/example.go b/pkg/lint/example/example.go index 0fc79ced..b6a5d85b 100644 --- a/pkg/lint/example/example.go +++ b/pkg/lint/example/example.go @@ -24,7 +24,7 @@ type TableNameLengthConfig struct { // NewTableNameLengthLinter creates a new table name length linter with default configuration. func NewTableNameLengthLinter() *TableNameLengthLinter { return &TableNameLengthLinter{ - maxLength: 64, // MySQL's limit + maxLength: 58, // MySQL's limit is 64 but we use 58 to allow for prefixes/suffixes } } @@ -36,10 +36,6 @@ func (l *TableNameLengthLinter) Name() string { return "table_name_length" } -func (l *TableNameLengthLinter) Category() string { - return "naming" -} - func (l *TableNameLengthLinter) Description() string { return "Checks that table names do not exceed the configured maximum length (default: 64 characters)" } @@ -96,10 +92,6 @@ func (l *DuplicateColumnLinter) Name() string { return "duplicate_column" } -func (l *DuplicateColumnLinter) Category() string { - return "schema" -} - func (l *DuplicateColumnLinter) Description() string { return "Detects duplicate column definitions in CREATE TABLE statements" } diff --git a/pkg/lint/example/example_test.go b/pkg/lint/example/example_test.go index 7cfd23c7..622b7a07 100644 --- a/pkg/lint/example/example_test.go +++ b/pkg/lint/example/example_test.go @@ -40,9 +40,9 @@ func TestTableNameLengthLinter_TooLong(t *testing.T) { } func TestTableNameLengthLinter_ExactlyAtLimit(t *testing.T) { - // Create a table name that's exactly 64 characters - exactName := "abcdefghij_abcdefghij_abcdefghij_abcdefghij_abcdefghij_abcdefghi" - require.Len(t, exactName, 64, "Test setup: name should be exactly 64 chars") + // Create a table name that's exactly 58 characters + exactName := "abcdefghij_abcdefghij_abcdefghij_abcdefghij_abcdefghij_abc" + require.Len(t, exactName, 58, "Test setup: name should be exactly 58 chars") sql := "CREATE TABLE " + exactName + " (id INT PRIMARY KEY)" ct, err := statement.ParseCreateTable(sql) @@ -51,7 +51,7 @@ func TestTableNameLengthLinter_ExactlyAtLimit(t *testing.T) { linter := NewTableNameLengthLinter() violations := linter.Lint([]*statement.CreateTable{ct}, nil) - assert.Empty(t, violations, "64 character name should be allowed") + assert.Empty(t, violations, "58 character name should be allowed") } func TestTableNameLengthLinter_Configure(t *testing.T) { diff --git a/pkg/lint/lint_invisible_index.go b/pkg/lint/lint_invisible_index.go index efd0e7bc..06fa1eb0 100644 --- a/pkg/lint/lint_invisible_index.go +++ b/pkg/lint/lint_invisible_index.go @@ -23,10 +23,6 @@ func (l *InvisibleIndexBeforeDropLinter) Name() string { return "invisible_index_before_drop" } -func (l *InvisibleIndexBeforeDropLinter) Category() string { - return "schema" -} - func (l *InvisibleIndexBeforeDropLinter) Description() string { return "Requires indexes to be made invisible before dropping them as a safety measure" } diff --git a/pkg/lint/lint_invisible_index_test.go b/pkg/lint/lint_invisible_index_test.go index 8c5d6ff7..291d9757 100644 --- a/pkg/lint/lint_invisible_index_test.go +++ b/pkg/lint/lint_invisible_index_test.go @@ -182,6 +182,5 @@ func TestInvisibleIndexBeforeDropLinter_Metadata(t *testing.T) { linter := &InvisibleIndexBeforeDropLinter{} assert.Equal(t, "invisible_index_before_drop", linter.Name()) - assert.Equal(t, "schema", linter.Category()) assert.NotEmpty(t, linter.Description()) } diff --git a/pkg/lint/lint_multiple_alter.go b/pkg/lint/lint_multiple_alter.go index 96b6a13e..49fe05a7 100644 --- a/pkg/lint/lint_multiple_alter.go +++ b/pkg/lint/lint_multiple_alter.go @@ -24,10 +24,6 @@ func (l *MultipleAlterTableLinter) Name() string { return "multiple_alter_table" } -func (l *MultipleAlterTableLinter) Category() string { - return "schema" -} - func (l *MultipleAlterTableLinter) Description() string { return "Detects multiple ALTER TABLE statements on the same table that could be combined" } diff --git a/pkg/lint/lint_multiple_alter_test.go b/pkg/lint/lint_multiple_alter_test.go index d0aea9f5..c027240c 100644 --- a/pkg/lint/lint_multiple_alter_test.go +++ b/pkg/lint/lint_multiple_alter_test.go @@ -230,7 +230,6 @@ func TestMultipleAlterTableLinter_Metadata(t *testing.T) { linter := &MultipleAlterTableLinter{} assert.Equal(t, "multiple_alter_table", linter.Name()) - assert.Equal(t, "schema", linter.Category()) assert.NotEmpty(t, linter.Description()) } diff --git a/pkg/lint/lint_primary_key_type.go b/pkg/lint/lint_primary_key_type.go index d26a8dc6..b37d81be 100644 --- a/pkg/lint/lint_primary_key_type.go +++ b/pkg/lint/lint_primary_key_type.go @@ -25,10 +25,6 @@ func (l *PrimaryKeyTypeLinter) Name() string { return "primary_key_type" } -func (l *PrimaryKeyTypeLinter) Category() string { - return "schema" -} - func (l *PrimaryKeyTypeLinter) Description() string { return "Ensures primary keys use BIGINT (preferably UNSIGNED) or BINARY/VARBINARY types" } diff --git a/pkg/lint/lint_primary_key_type_test.go b/pkg/lint/lint_primary_key_type_test.go index 498c0494..1887f745 100644 --- a/pkg/lint/lint_primary_key_type_test.go +++ b/pkg/lint/lint_primary_key_type_test.go @@ -310,7 +310,6 @@ func TestPrimaryKeyTypeLinter_Metadata(t *testing.T) { linter := &PrimaryKeyTypeLinter{} assert.Equal(t, "primary_key_type", linter.Name()) - assert.Equal(t, "schema", linter.Category()) assert.NotEmpty(t, linter.Description()) } diff --git a/pkg/lint/lint_test.go b/pkg/lint/lint_test.go index a9e36437..388804c0 100644 --- a/pkg/lint/lint_test.go +++ b/pkg/lint/lint_test.go @@ -11,7 +11,6 @@ import ( // Mock linter for testing type mockLinter struct { name string - category string description string violations []Violation } @@ -22,7 +21,6 @@ func (m *mockLinter) String() string { } func (m *mockLinter) Name() string { return m.name } -func (m *mockLinter) Category() string { return m.category } func (m *mockLinter) Description() string { return m.description } func (m *mockLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []Violation { return m.violations @@ -58,7 +56,6 @@ func TestRegister(t *testing.T) { linter := &mockLinter{ name: "test_linter", - category: "test", description: "A test linter", } @@ -77,9 +74,9 @@ func TestRegister(t *testing.T) { func TestRegisterMultiple(t *testing.T) { Reset() - linter1 := &mockLinter{name: "linter1", category: "cat1"} - linter2 := &mockLinter{name: "linter2", category: "cat2"} - linter3 := &mockLinter{name: "linter3", category: "cat1"} + linter1 := &mockLinter{name: "linter1"} + linter2 := &mockLinter{name: "linter2"} + linter3 := &mockLinter{name: "linter3"} Register(linter1) Register(linter2) @@ -92,34 +89,10 @@ func TestRegisterMultiple(t *testing.T) { assert.Contains(t, names, "linter3") } -func TestListByCategory(t *testing.T) { - Reset() - - linter1 := &mockLinter{name: "linter1", category: "naming"} - linter2 := &mockLinter{name: "linter2", category: "performance"} - linter3 := &mockLinter{name: "linter3", category: "naming"} - - Register(linter1) - Register(linter2) - Register(linter3) - - namingLinters := ListByCategory("naming") - assert.Len(t, namingLinters, 2) - assert.Contains(t, namingLinters, "linter1") - assert.Contains(t, namingLinters, "linter3") - - perfLinters := ListByCategory("performance") - assert.Len(t, perfLinters, 1) - assert.Contains(t, perfLinters, "linter2") - - emptyLinters := ListByCategory("nonexistent") - assert.Empty(t, emptyLinters) -} - func TestEnableDisable(t *testing.T) { Reset() - linter := &mockLinter{name: "test_linter", category: "test"} + linter := &mockLinter{name: "test_linter"} Register(linter) // Linters are enabled by default @@ -153,7 +126,6 @@ func TestGet(t *testing.T) { linter := &mockLinter{ name: "test_linter", - category: "test", description: "A test linter", } Register(linter) @@ -161,7 +133,6 @@ func TestGet(t *testing.T) { retrieved, err := Get("test_linter") require.NoError(t, err) assert.Equal(t, "test_linter", retrieved.Name()) - assert.Equal(t, "test", retrieved.Category()) assert.Equal(t, "A test linter", retrieved.Description()) } @@ -184,8 +155,7 @@ func TestRunLinters_SingleLinter(t *testing.T) { Reset() linter := &mockLinter{ - name: "test_linter", - category: "test", + name: "test_linter", } expectedViolations := []Violation{ @@ -210,16 +180,14 @@ func TestRunLinters_MultipleLinters(t *testing.T) { Reset() linter1 := &mockLinter{ - name: "linter1", - category: "test", + name: "linter1", } linter1.violations = []Violation{ {Linter: linter1, Severity: SeverityError, Message: "Error 1"}, } linter2 := &mockLinter{ - name: "linter2", - category: "test", + name: "linter2", } linter2.violations = []Violation{ {Linter: linter2, Severity: SeverityWarning, Message: "Warning 1"}, @@ -237,8 +205,7 @@ func TestRunLinters_WithConfig_Disabled(t *testing.T) { Reset() linter := &mockLinter{ - name: "test_linter", - category: "test", + name: "test_linter", } linter.violations = []Violation{ {Linter: linter, Severity: SeverityError, Message: "Should not see this"}, @@ -259,8 +226,7 @@ func TestRunLinters_WithConfig_Enabled(t *testing.T) { Reset() linter := &mockLinter{ - name: "test_linter", - category: "test", + name: "test_linter", } linter.violations = []Violation{ {Linter: linter, Severity: SeverityError, Message: "Should see this"}, @@ -286,7 +252,6 @@ func TestRunLinters_ConfigurableLinter(t *testing.T) { linter := &mockConfigurableLinter{} linter.name = "configurable_linter" - linter.category = "test" linter.violations = []Violation{ {Linter: linter, Severity: SeverityError, Message: "Test"}, } @@ -309,7 +274,6 @@ func TestRunLinters_ConfigurableLinter_NoConfig(t *testing.T) { linter := &mockConfigurableLinter{} linter.name = "configurable_linter" - linter.category = "test" linter.violations = []Violation{ {Linter: linter, Severity: SeverityError, Message: "Test"}, } @@ -366,8 +330,8 @@ func TestFilterBySeverity(t *testing.T) { } func TestFilterByLinter(t *testing.T) { - linter1 := &mockLinter{name: "linter1", category: "test"} - linter2 := &mockLinter{name: "linter2", category: "test"} + linter1 := &mockLinter{name: "linter1"} + linter2 := &mockLinter{name: "linter2"} violations := []Violation{ {Linter: linter1, Message: "Message 1"}, @@ -392,32 +356,19 @@ func TestListSorted(t *testing.T) { Reset() // Register in non-alphabetical order - Register(&mockLinter{name: "zebra", category: "test"}) - Register(&mockLinter{name: "alpha", category: "test"}) - Register(&mockLinter{name: "beta", category: "test"}) + Register(&mockLinter{name: "zebra"}) + Register(&mockLinter{name: "alpha"}) + Register(&mockLinter{name: "beta"}) names := List() assert.Equal(t, []string{"alpha", "beta", "zebra"}, names) } -func TestListByCategorySorted(t *testing.T) { - Reset() - - // Register in non-alphabetical order - Register(&mockLinter{name: "zebra", category: "cat1"}) - Register(&mockLinter{name: "alpha", category: "cat1"}) - Register(&mockLinter{name: "beta", category: "cat2"}) - Register(&mockLinter{name: "gamma", category: "cat1"}) - - names := ListByCategory("cat1") - assert.Equal(t, []string{"alpha", "gamma", "zebra"}, names) -} - func TestReset(t *testing.T) { Reset() - Register(&mockLinter{name: "linter1", category: "test"}) - Register(&mockLinter{name: "linter2", category: "test"}) + Register(&mockLinter{name: "linter1"}) + Register(&mockLinter{name: "linter2"}) assert.Len(t, List(), 2) @@ -430,7 +381,7 @@ func TestViolationWithLocation(t *testing.T) { column := "test_column" index := "test_index" constraint := "test_constraint" - linter := &mockLinter{name: "test_linter", category: "test"} + linter := &mockLinter{name: "test_linter"} violation := Violation{ Linter: linter, @@ -452,7 +403,7 @@ func TestViolationWithLocation(t *testing.T) { func TestViolationWithSuggestion(t *testing.T) { suggestion := "Try this instead" - linter := &mockLinter{name: "test_linter", category: "test"} + linter := &mockLinter{name: "test_linter"} violation := Violation{ Linter: linter, @@ -466,7 +417,7 @@ func TestViolationWithSuggestion(t *testing.T) { } func TestViolationWithContext(t *testing.T) { - linter := &mockLinter{name: "test_linter", category: "test"} + linter := &mockLinter{name: "test_linter"} violation := Violation{ Linter: linter, diff --git a/pkg/lint/linter.go b/pkg/lint/linter.go index c2dd2984..09564f0c 100644 --- a/pkg/lint/linter.go +++ b/pkg/lint/linter.go @@ -9,10 +9,6 @@ type Linter interface { // Name returns the unique name of this linter Name() string - // Category returns the category this linter belongs to - // (e.g., "naming", "performance", "security", "schema") - Category() string - // Description returns a human-readable description of what this linter checks Description() string diff --git a/pkg/lint/registry.go b/pkg/lint/registry.go index 92853873..d6a5d1c5 100644 --- a/pkg/lint/registry.go +++ b/pkg/lint/registry.go @@ -8,9 +8,8 @@ import ( // linter represents a registered linter with metadata type linter struct { - impl Linter - category string - enabled bool + impl Linter + enabled bool } var ( @@ -30,9 +29,8 @@ func Register(l Linter) { } linters[l.Name()] = linter{ - impl: l, - category: l.Category(), - enabled: true, + impl: l, + enabled: true, } } @@ -89,24 +87,6 @@ func List() []string { return names } -// ListByCategory returns the names of all linters in a specific category in sorted order. -func ListByCategory(category string) []string { - lock.RLock() - defer lock.RUnlock() - - var names []string - - for name, linter := range linters { - if linter.category == category { - names = append(names, name) - } - } - - sort.Strings(names) - - return names -} - // Get returns a linter by name. // Returns an error if the linter is not found. func Get(name string) (Linter, error) { From b40746ad88aa54bc50052f7cd9161893d1cda3d0 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Mon, 20 Oct 2025 12:36:08 -0600 Subject: [PATCH 04/12] add native linting to experimental section of USAGE.md --- USAGE.md | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/USAGE.md b/USAGE.md index d580aad7..56f46057 100644 --- a/USAGE.md +++ b/USAGE.md @@ -29,6 +29,8 @@ - [Experimental Features](#experimental-features) - [enable-experimental-multi-table-support](#enable-experimental-multi-table-support) - [enable-experimental-buffered-copy](#enable-experimental-buffered-copy) + - [`move` command](#move-command) + - [native linting support](#native-linting-support) ## Getting Started @@ -365,4 +367,15 @@ This feature provides a new top level binary `move`, which can copy whole schema This command depends strongly on the experimental buffered copy and multi-table support, both which are currently experimental. There is not too much which is special to move on top of these two features, so once they become stable, so too can `move`. -It is anticipated that `move` will need to provide some pluggable method of cutover so external metadata systems can be updated. There is no current design for this. \ No newline at end of file +It is anticipated that `move` will need to provide some pluggable method of cutover so external metadata systems can be updated. There is no current design for this. + + +### native linting support + +**Feature Description** + +This feature adds native linting support to Spirit, allowing for various rules to be applied to schema changes before they are executed. + +**Current Status** + +This feature is partially complete. It relies on new support for parsing CREATE TABLE statements (see `pkg/statetement/parse_create_table.go`). There are so far only a few linters implemented. This functionality is not currently exposed via command line flags. \ No newline at end of file From 7129cb6b5809b952811d1b3001abeb25c2c31d82 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Mon, 20 Oct 2025 18:30:46 -0700 Subject: [PATCH 05/12] make linter String methods call a ling.Stringer function for consistency --- pkg/lint/example/example.go | 2 +- pkg/lint/lint_invisible_index.go | 2 +- pkg/lint/lint_multiple_alter.go | 2 +- pkg/lint/lint_primary_key_type.go | 2 +- pkg/lint/linter.go | 6 ++++++ 5 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pkg/lint/example/example.go b/pkg/lint/example/example.go index b6a5d85b..fec9909a 100644 --- a/pkg/lint/example/example.go +++ b/pkg/lint/example/example.go @@ -29,7 +29,7 @@ func NewTableNameLengthLinter() *TableNameLengthLinter { } func (l *TableNameLengthLinter) String() string { - return l.Name() + return lint.Stringer(l) } func (l *TableNameLengthLinter) Name() string { diff --git a/pkg/lint/lint_invisible_index.go b/pkg/lint/lint_invisible_index.go index 06fa1eb0..9c324aa9 100644 --- a/pkg/lint/lint_invisible_index.go +++ b/pkg/lint/lint_invisible_index.go @@ -16,7 +16,7 @@ func init() { type InvisibleIndexBeforeDropLinter struct{} func (l *InvisibleIndexBeforeDropLinter) String() string { - return l.Name() + return Stringer(l) } func (l *InvisibleIndexBeforeDropLinter) Name() string { diff --git a/pkg/lint/lint_multiple_alter.go b/pkg/lint/lint_multiple_alter.go index 49fe05a7..861fb4d2 100644 --- a/pkg/lint/lint_multiple_alter.go +++ b/pkg/lint/lint_multiple_alter.go @@ -17,7 +17,7 @@ func init() { type MultipleAlterTableLinter struct{} func (l *MultipleAlterTableLinter) String() string { - return l.Name() + return Stringer(l) } func (l *MultipleAlterTableLinter) Name() string { diff --git a/pkg/lint/lint_primary_key_type.go b/pkg/lint/lint_primary_key_type.go index b37d81be..b0e81d52 100644 --- a/pkg/lint/lint_primary_key_type.go +++ b/pkg/lint/lint_primary_key_type.go @@ -18,7 +18,7 @@ func init() { type PrimaryKeyTypeLinter struct{} func (l *PrimaryKeyTypeLinter) String() string { - return l.Name() + return Stringer(l) } func (l *PrimaryKeyTypeLinter) Name() string { diff --git a/pkg/lint/linter.go b/pkg/lint/linter.go index 09564f0c..e780928c 100644 --- a/pkg/lint/linter.go +++ b/pkg/lint/linter.go @@ -30,3 +30,9 @@ type ConfigurableLinter interface { // DefaultConfig returns the default configuration for this linter DefaultConfig() any } + +// Stringer returns a string representation of the linter +// This is a helper function used by linters' String() methods. +func Stringer(l Linter) string { + return l.Name() + " - " + l.Description() +} From 3063f106f617ab33cfc44410271ab7f2512d7807 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Mon, 20 Oct 2025 20:30:54 -0700 Subject: [PATCH 06/12] Make invisible_index_before_drop configurable and modify RunLinters to return an error --- pkg/lint/lint.go | 11 ++++++-- pkg/lint/lint_invisible_index.go | 43 ++++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/pkg/lint/lint.go b/pkg/lint/lint.go index 62de1ed3..ed4a96a0 100644 --- a/pkg/lint/lint.go +++ b/pkg/lint/lint.go @@ -48,6 +48,10 @@ package lint import ( + "errors" + "fmt" + "os" + "github.com/block/spirit/pkg/statement" ) @@ -74,7 +78,8 @@ type Config struct { // // If a linter implements ConfigurableLinter and has settings in config.Settings, // those settings are applied before running the linter. -func RunLinters(createTables []*statement.CreateTable, alterStatements []*statement.AbstractStatement, config Config) []Violation { +func RunLinters(createTables []*statement.CreateTable, alterStatements []*statement.AbstractStatement, config Config) ([]Violation, error) { + var errs []error lock.RLock() defer lock.RUnlock() @@ -104,6 +109,8 @@ func RunLinters(createTables []*statement.CreateTable, alterStatements []*statem if err != nil { // Configuration error - skip this linter // In a production system, we might want to log this + fmt.Fprintf(os.Stderr, "Error configuring %s: %s\n", name, err) + errs = append(errs, err) continue } } @@ -114,7 +121,7 @@ func RunLinters(createTables []*statement.CreateTable, alterStatements []*statem violations = append(violations, lintViolations...) } - return violations + return violations, errors.Join(errs...) } // HasErrors returns true if any violations have ERROR severity. diff --git a/pkg/lint/lint_invisible_index.go b/pkg/lint/lint_invisible_index.go index 9c324aa9..14b6eace 100644 --- a/pkg/lint/lint_invisible_index.go +++ b/pkg/lint/lint_invisible_index.go @@ -1,7 +1,9 @@ package lint import ( + "errors" "fmt" + "strings" "github.com/block/spirit/pkg/statement" "github.com/pingcap/tidb/pkg/parser/ast" @@ -13,7 +15,9 @@ func init() { // InvisibleIndexBeforeDropLinter checks that indexes are made invisible before dropping. // This is a safety practice to ensure the index is not needed before permanently removing it. -type InvisibleIndexBeforeDropLinter struct{} +type InvisibleIndexBeforeDropLinter struct { + raiseError bool +} func (l *InvisibleIndexBeforeDropLinter) String() string { return Stringer(l) @@ -27,7 +31,42 @@ func (l *InvisibleIndexBeforeDropLinter) Description() string { return "Requires indexes to be made invisible before dropping them as a safety measure" } +func (l *InvisibleIndexBeforeDropLinter) Configure(a any) error { + c, ok := a.(map[string]string) + if !ok { + return errors.New(l.Name() + " config must be a map[string]string") + } + for k, v := range c { + switch k { + case "raiseError": + if strings.EqualFold(v, "true") { + l.raiseError = true + break + } + if strings.EqualFold(v, "false") { + l.raiseError = false + break + } + return fmt.Errorf("invalid value for %s: %s", k, v) + default: + return fmt.Errorf("unknown config key for %s: %s", l.Name(), k) + } + } + return nil +} +func (l *InvisibleIndexBeforeDropLinter) DefaultConfig() any { + return map[string]string{ + "raiseError": "true", + } +} + +var _ ConfigurableLinter = &InvisibleIndexBeforeDropLinter{} + func (l *InvisibleIndexBeforeDropLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []Violation { + severity := SeverityWarning + if l.raiseError { + severity = SeverityError + } var violations []Violation for _, stmt := range statements { @@ -70,7 +109,7 @@ func (l *InvisibleIndexBeforeDropLinter) Lint(createTables []*statement.CreateTa suggestion := fmt.Sprintf("First make the index invisible: ALTER TABLE %s ALTER INDEX %s INVISIBLE", tableName, indexName) violations = append(violations, Violation{ Linter: l, - Severity: SeverityWarning, + Severity: severity, Message: fmt.Sprintf("Index '%s' should be made invisible before dropping to ensure it's not needed", indexName), Location: &Location{ Table: tableName, From 2444df42daa6c07024c6914ec3095896457b3d07 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Mon, 20 Oct 2025 20:52:34 -0700 Subject: [PATCH 07/12] adapt linters and tests to new RunLinters signature that returns errors --- pkg/lint/README.md | 12 +++++++++--- pkg/lint/example/example_test.go | 12 ++++++++---- pkg/lint/lint.go | 2 ++ pkg/lint/lint_invisible_index.go | 5 +++++ pkg/lint/lint_invisible_index_test.go | 6 ++++-- pkg/lint/lint_multiple_alter_test.go | 6 ++++-- pkg/lint/lint_primary_key_type_test.go | 6 ++++-- pkg/lint/lint_test.go | 21 ++++++++++++++------- 8 files changed, 50 insertions(+), 20 deletions(-) diff --git a/pkg/lint/README.md b/pkg/lint/README.md index be5de00e..9a668aac 100644 --- a/pkg/lint/README.md +++ b/pkg/lint/README.md @@ -19,7 +19,10 @@ import ( ) // All built-in linters are automatically registered! -violations := lint.RunLinters(tables, stmts, lint.Config{}) +violations, err := lint.RunLinters(tables, stmts, lint.Config{}) +if err != nil { + // Handle configuration errors +} // Check for errors if lint.HasErrors(violations) { @@ -84,12 +87,15 @@ func (l *MyCustomLinter) Lint(createTables []*statement.CreateTable, alterStatem ```go // Disable specific linters -violations := lint.RunLinters(tables, stmts, lint.Config{ +violations, err := lint.RunLinters(tables, stmts, lint.Config{ Enabled: map[string]bool{ "table_name_length": false, "duplicate_column": true, }, }) +if err != nil { + // Handle configuration errors +} ``` ## Core Types @@ -137,7 +143,7 @@ type Location struct { ### Execution -- `RunLinters(createTables, alterStatements, config)` - Run all enabled linters +- `RunLinters(createTables, alterStatements, config) ([]Violation, error)` - Run all enabled linters, returns violations and any configuration errors - `HasErrors(violations)` - Check if any violations are errors - `HasWarnings(violations)` - Check if any violations are warnings - `FilterBySeverity(violations, severity)` - Filter by severity level diff --git a/pkg/lint/example/example_test.go b/pkg/lint/example/example_test.go index 622b7a07..24b4a9f3 100644 --- a/pkg/lint/example/example_test.go +++ b/pkg/lint/example/example_test.go @@ -177,7 +177,8 @@ func TestExampleLinters_Integration(t *testing.T) { require.NoError(t, err) // Run all linters - violations := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{}) + violations, err := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{}) + require.NoError(t, err) // Should have violations from both linters require.Len(t, violations, 2) @@ -204,12 +205,13 @@ func TestExampleLinters_WithConfig(t *testing.T) { require.NoError(t, err) // Disable the table name length linter - violations := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{ + violations, err := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{ Enabled: map[string]bool{ "table_name_length": false, "duplicate_column": true, }, }) + require.NoError(t, err) // Should only have violation from duplicate_column linter require.Len(t, violations, 1) @@ -230,15 +232,17 @@ func TestTableNameLengthLinter_WithConfigSettings(t *testing.T) { require.NoError(t, err) // With default config (64), should pass - violations := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{}) + violations, err := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{}) + require.NoError(t, err) assert.Empty(t, violations) // Configure max length to 40 via Config.Settings - violations = lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{ + violations, err = lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{ Settings: map[string]any{ "table_name_length": TableNameLengthConfig{MaxLength: 40}, }, }) + require.NoError(t, err) // Should now have a violation require.Len(t, violations, 1) diff --git a/pkg/lint/lint.go b/pkg/lint/lint.go index ed4a96a0..37b5081c 100644 --- a/pkg/lint/lint.go +++ b/pkg/lint/lint.go @@ -80,6 +80,7 @@ type Config struct { // those settings are applied before running the linter. func RunLinters(createTables []*statement.CreateTable, alterStatements []*statement.AbstractStatement, config Config) ([]Violation, error) { var errs []error + lock.RLock() defer lock.RUnlock() @@ -111,6 +112,7 @@ func RunLinters(createTables []*statement.CreateTable, alterStatements []*statem // In a production system, we might want to log this fmt.Fprintf(os.Stderr, "Error configuring %s: %s\n", name, err) errs = append(errs, err) + continue } } diff --git a/pkg/lint/lint_invisible_index.go b/pkg/lint/lint_invisible_index.go index 14b6eace..fb2e8904 100644 --- a/pkg/lint/lint_invisible_index.go +++ b/pkg/lint/lint_invisible_index.go @@ -36,6 +36,7 @@ func (l *InvisibleIndexBeforeDropLinter) Configure(a any) error { if !ok { return errors.New(l.Name() + " config must be a map[string]string") } + for k, v := range c { switch k { case "raiseError": @@ -43,15 +44,18 @@ func (l *InvisibleIndexBeforeDropLinter) Configure(a any) error { l.raiseError = true break } + if strings.EqualFold(v, "false") { l.raiseError = false break } + return fmt.Errorf("invalid value for %s: %s", k, v) default: return fmt.Errorf("unknown config key for %s: %s", l.Name(), k) } } + return nil } func (l *InvisibleIndexBeforeDropLinter) DefaultConfig() any { @@ -67,6 +71,7 @@ func (l *InvisibleIndexBeforeDropLinter) Lint(createTables []*statement.CreateTa if l.raiseError { severity = SeverityError } + var violations []Violation for _, stmt := range statements { diff --git a/pkg/lint/lint_invisible_index_test.go b/pkg/lint/lint_invisible_index_test.go index 291d9757..9f0ca058 100644 --- a/pkg/lint/lint_invisible_index_test.go +++ b/pkg/lint/lint_invisible_index_test.go @@ -152,7 +152,8 @@ func TestInvisibleIndexBeforeDropLinter_Integration(t *testing.T) { stmts, err := statement.New(sql) require.NoError(t, err) - violations := RunLinters(nil, stmts, Config{}) + violations, err := RunLinters(nil, stmts, Config{}) + require.NoError(t, err) require.Len(t, violations, 1) assert.Equal(t, "invisible_index_before_drop", violations[0].Linter.Name()) @@ -168,11 +169,12 @@ func TestInvisibleIndexBeforeDropLinter_IntegrationDisabled(t *testing.T) { require.NoError(t, err) // Disable the linter - violations := RunLinters(nil, stmts, Config{ + violations, err := RunLinters(nil, stmts, Config{ Enabled: map[string]bool{ "invisible_index_before_drop": false, }, }) + require.NoError(t, err) // Should not have violations when disabled assert.Empty(t, violations) diff --git a/pkg/lint/lint_multiple_alter_test.go b/pkg/lint/lint_multiple_alter_test.go index c027240c..4fd46c94 100644 --- a/pkg/lint/lint_multiple_alter_test.go +++ b/pkg/lint/lint_multiple_alter_test.go @@ -202,7 +202,8 @@ func TestMultipleAlterTableLinter_Integration(t *testing.T) { stmts, err := statement.New(sql) require.NoError(t, err) - violations := RunLinters(nil, stmts, Config{}) + violations, err := RunLinters(nil, stmts, Config{}) + require.NoError(t, err) require.Len(t, violations, 1) assert.Equal(t, "multiple_alter_table", violations[0].Linter.Name()) @@ -217,11 +218,12 @@ func TestMultipleAlterTableLinter_IntegrationDisabled(t *testing.T) { stmts, err := statement.New(sql) require.NoError(t, err) - violations := RunLinters(nil, stmts, Config{ + violations, err := RunLinters(nil, stmts, Config{ Enabled: map[string]bool{ "multiple_alter_table": false, }, }) + require.NoError(t, err) assert.Empty(t, violations) } diff --git a/pkg/lint/lint_primary_key_type_test.go b/pkg/lint/lint_primary_key_type_test.go index 1887f745..1ee93b36 100644 --- a/pkg/lint/lint_primary_key_type_test.go +++ b/pkg/lint/lint_primary_key_type_test.go @@ -280,7 +280,8 @@ func TestPrimaryKeyTypeLinter_Integration(t *testing.T) { ct, err := statement.ParseCreateTable(sql) require.NoError(t, err) - violations := RunLinters([]*statement.CreateTable{ct}, nil, Config{}) + violations, err := RunLinters([]*statement.CreateTable{ct}, nil, Config{}) + require.NoError(t, err) require.Len(t, violations, 1) assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) @@ -297,11 +298,12 @@ func TestPrimaryKeyTypeLinter_IntegrationDisabled(t *testing.T) { ct, err := statement.ParseCreateTable(sql) require.NoError(t, err) - violations := RunLinters([]*statement.CreateTable{ct}, nil, Config{ + violations, err := RunLinters([]*statement.CreateTable{ct}, nil, Config{ Enabled: map[string]bool{ "primary_key_type": false, }, }) + require.NoError(t, err) assert.Empty(t, violations) } diff --git a/pkg/lint/lint_test.go b/pkg/lint/lint_test.go index 388804c0..60801eed 100644 --- a/pkg/lint/lint_test.go +++ b/pkg/lint/lint_test.go @@ -147,7 +147,8 @@ func TestGetNonexistent(t *testing.T) { func TestRunLinters_Empty(t *testing.T) { Reset() - violations := RunLinters(nil, nil, Config{}) + violations, err := RunLinters(nil, nil, Config{}) + require.NoError(t, err) assert.Empty(t, violations) } @@ -169,7 +170,8 @@ func TestRunLinters_SingleLinter(t *testing.T) { Register(linter) - violations := RunLinters(nil, nil, Config{}) + violations, err := RunLinters(nil, nil, Config{}) + require.NoError(t, err) assert.Len(t, violations, 1) assert.Equal(t, "test_linter", violations[0].Linter.Name()) assert.Equal(t, SeverityError, violations[0].Severity) @@ -197,7 +199,8 @@ func TestRunLinters_MultipleLinters(t *testing.T) { Register(linter1) Register(linter2) - violations := RunLinters(nil, nil, Config{}) + violations, err := RunLinters(nil, nil, Config{}) + require.NoError(t, err) assert.Len(t, violations, 3) } @@ -213,11 +216,12 @@ func TestRunLinters_WithConfig_Disabled(t *testing.T) { Register(linter) // Disable the linter via config - violations := RunLinters(nil, nil, Config{ + violations, err := RunLinters(nil, nil, Config{ Enabled: map[string]bool{ "test_linter": false, }, }) + require.NoError(t, err) assert.Empty(t, violations) } @@ -237,11 +241,12 @@ func TestRunLinters_WithConfig_Enabled(t *testing.T) { require.NoError(t, Disable("test_linter")) // But explicitly enable via config - violations := RunLinters(nil, nil, Config{ + violations, err := RunLinters(nil, nil, Config{ Enabled: map[string]bool{ "test_linter": true, }, }) + require.NoError(t, err) assert.Len(t, violations, 1) assert.Equal(t, "Should see this", violations[0].Message) @@ -258,11 +263,12 @@ func TestRunLinters_ConfigurableLinter(t *testing.T) { Register(linter) config := map[string]string{"key": "value"} - violations := RunLinters(nil, nil, Config{ + violations, err := RunLinters(nil, nil, Config{ Settings: map[string]any{ "configurable_linter": config, }, }) + require.NoError(t, err) assert.Len(t, violations, 1) assert.True(t, linter.configCalled) @@ -279,7 +285,8 @@ func TestRunLinters_ConfigurableLinter_NoConfig(t *testing.T) { } Register(linter) - violations := RunLinters(nil, nil, Config{}) + violations, err := RunLinters(nil, nil, Config{}) + require.NoError(t, err) assert.Len(t, violations, 1) assert.False(t, linter.configCalled) From 5b6e812d1a15161a4b0d7477b1e4df5264f4b0d9 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Mon, 20 Oct 2025 22:33:50 -0700 Subject: [PATCH 08/12] migrate linter configuration to map[string]string. wrangle handling ot linter default values. --- pkg/lint/README.md | 46 ++++- pkg/lint/example/example.go | 40 ++-- pkg/lint/example/example_test.go | 32 ++-- pkg/lint/lint.go | 52 +++-- pkg/lint/lint_invisible_index.go | 29 +-- pkg/lint/lint_invisible_index_test.go | 264 ++++++++++++++++++++++++++ pkg/lint/lint_primary_key_type.go | 4 - pkg/lint/lint_test.go | 117 +++++++++++- pkg/lint/linter.go | 23 ++- pkg/lint/registry.go | 6 +- 10 files changed, 526 insertions(+), 87 deletions(-) diff --git a/pkg/lint/README.md b/pkg/lint/README.md index 9a668aac..1173ce9e 100644 --- a/pkg/lint/README.md +++ b/pkg/lint/README.md @@ -85,12 +85,14 @@ func (l *MyCustomLinter) Lint(createTables []*statement.CreateTable, alterStatem ### Configuring Linters +#### Enabling/Disabling Linters + ```go // Disable specific linters violations, err := lint.RunLinters(tables, stmts, lint.Config{ Enabled: map[string]bool{ - "table_name_length": false, - "duplicate_column": true, + "invisible_index_before_drop": false, + "primary_key_type": true, }, }) if err != nil { @@ -98,6 +100,25 @@ if err != nil { } ``` +#### Configurable Linters + +Some linters support additional configuration options via the `Settings` field. Linters that implement the `ConfigurableLinter` interface accept settings as `map[string]string`: + +```go +violations, err := lint.RunLinters(tables, stmts, lint.Config{ + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "true", // Make violations errors instead of warnings + }, + }, +}) +if err != nil { + // Handle configuration errors (e.g., invalid settings) +} +``` + +Each configurable linter defines its own settings keys and values. See the individual linter documentation below for available options. + ## Core Types ### Severity Levels @@ -156,10 +177,17 @@ The `lint` package includes several linters: ### invisible_index_before_drop **Category**: schema -**Severity**: Warning +**Severity**: Warning (default), Error (configurable) +**Configurable**: Yes Requires indexes to be made invisible before dropping them as a safety measure. This ensures the index isn't needed before permanently removing it. +**Configuration Options:** + +- `raiseError` (string): Set to `"true"` to make violations errors instead of warnings. Default: `false`. + +**Example Usage:** + ```go // ❌ Violation ALTER TABLE users DROP INDEX idx_email; @@ -170,6 +198,18 @@ ALTER TABLE users ALTER INDEX idx_email INVISIBLE; ALTER TABLE users DROP INDEX idx_email; ``` +**Configuration Example:** + +```go +violations, err := lint.RunLinters(tables, stmts, lint.Config{ + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "true", // Violations will be errors + }, + }, +}) +``` + ### multiple_alter_table **Category**: schema diff --git a/pkg/lint/example/example.go b/pkg/lint/example/example.go index fec9909a..c9055340 100644 --- a/pkg/lint/example/example.go +++ b/pkg/lint/example/example.go @@ -3,8 +3,8 @@ package example import ( - "errors" "fmt" + "strconv" "github.com/block/spirit/pkg/lint" "github.com/block/spirit/pkg/statement" @@ -16,11 +16,6 @@ type TableNameLengthLinter struct { maxLength int } -// TableNameLengthConfig holds configuration for the table name length linter. -type TableNameLengthConfig struct { - MaxLength int `json:"max_length"` -} - // NewTableNameLengthLinter creates a new table name length linter with default configuration. func NewTableNameLengthLinter() *TableNameLengthLinter { return &TableNameLengthLinter{ @@ -37,26 +32,33 @@ func (l *TableNameLengthLinter) Name() string { } func (l *TableNameLengthLinter) Description() string { - return "Checks that table names do not exceed the configured maximum length (default: 64 characters)" + return "Checks that table names do not exceed the configured maximum length (default: 58 characters)" } -func (l *TableNameLengthLinter) DefaultConfig() any { - return TableNameLengthConfig{ - MaxLength: 64, +func (l *TableNameLengthLinter) DefaultConfig() map[string]string { + return map[string]string{ + "maxLength": "58", } } -func (l *TableNameLengthLinter) Configure(config any) error { - cfg, ok := config.(TableNameLengthConfig) - if !ok { - return errors.New("invalid config type for table_name_length linter: expected TableNameLengthConfig") - } +func (l *TableNameLengthLinter) Configure(config map[string]string) error { + for k, v := range config { + switch k { + case "maxLength": + maxLen, err := strconv.Atoi(v) + if err != nil { + return fmt.Errorf("maxLength must be a valid integer, got %q: %w", v, err) + } - if cfg.MaxLength <= 0 { - return fmt.Errorf("max_length must be positive, got %d", cfg.MaxLength) - } + if maxLen <= 0 { + return fmt.Errorf("maxLength must be positive, got %d", maxLen) + } - l.maxLength = cfg.MaxLength + l.maxLength = maxLen + default: + return fmt.Errorf("unknown config key for %s: %s", l.Name(), k) + } + } return nil } diff --git a/pkg/lint/example/example_test.go b/pkg/lint/example/example_test.go index 24b4a9f3..c14c59f6 100644 --- a/pkg/lint/example/example_test.go +++ b/pkg/lint/example/example_test.go @@ -65,12 +65,12 @@ func TestTableNameLengthLinter_Configure(t *testing.T) { linter := NewTableNameLengthLinter() - // With default config (64), should pass + // With default config (58), should pass violations := linter.Lint([]*statement.CreateTable{ct}, nil) assert.Empty(t, violations) // Configure to max length of 40 - err = linter.Configure(TableNameLengthConfig{MaxLength: 40}) + err = linter.Configure(map[string]string{"maxLength": "40"}) require.NoError(t, err) // Now should fail @@ -82,20 +82,25 @@ func TestTableNameLengthLinter_Configure(t *testing.T) { func TestTableNameLengthLinter_Configure_InvalidConfig(t *testing.T) { linter := NewTableNameLengthLinter() - // Wrong type - err := linter.Configure("invalid") + // Invalid integer + err := linter.Configure(map[string]string{"maxLength": "invalid"}) assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid config type") + assert.Contains(t, err.Error(), "must be a valid integer") // Zero length - err = linter.Configure(TableNameLengthConfig{MaxLength: 0}) + err = linter.Configure(map[string]string{"maxLength": "0"}) assert.Error(t, err) assert.Contains(t, err.Error(), "must be positive") // Negative length - err = linter.Configure(TableNameLengthConfig{MaxLength: -1}) + err = linter.Configure(map[string]string{"maxLength": "-1"}) assert.Error(t, err) assert.Contains(t, err.Error(), "must be positive") + + // Unknown key + err = linter.Configure(map[string]string{"unknownKey": "value"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown config key") } func TestTableNameLengthLinter_DefaultConfig(t *testing.T) { @@ -103,10 +108,7 @@ func TestTableNameLengthLinter_DefaultConfig(t *testing.T) { config := linter.DefaultConfig() require.NotNil(t, config) - - cfg, ok := config.(TableNameLengthConfig) - require.True(t, ok) - assert.Equal(t, 64, cfg.MaxLength) + assert.Equal(t, "58", config["maxLength"]) } func TestDuplicateColumnLinter_NoDuplicates(t *testing.T) { @@ -231,15 +233,17 @@ func TestTableNameLengthLinter_WithConfigSettings(t *testing.T) { ct, err := statement.ParseCreateTable(sql) require.NoError(t, err) - // With default config (64), should pass + // With default config (58), should pass violations, err := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{}) require.NoError(t, err) assert.Empty(t, violations) // Configure max length to 40 via Config.Settings violations, err = lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{ - Settings: map[string]any{ - "table_name_length": TableNameLengthConfig{MaxLength: 40}, + Settings: map[string]map[string]string{ + "table_name_length": { + "maxLength": "40", + }, }, }) require.NoError(t, err) diff --git a/pkg/lint/lint.go b/pkg/lint/lint.go index 37b5081c..ea8a3a48 100644 --- a/pkg/lint/lint.go +++ b/pkg/lint/lint.go @@ -12,7 +12,7 @@ // package naming // // func init() { -// lint.Register(&TableNameLinter{}) +// lint.Register(TableNameLinter{}) // } // // // Later, run all linters: @@ -44,7 +44,16 @@ // } // // Configurable linters can implement the ConfigurableLinter interface to accept -// custom settings via Config.Settings. +// custom settings via Config.Settings. Settings must be provided as map[string]string: +// +// config := lint.Config{ +// Settings: map[string]map[string]string{ +// "my_linter": { +// "option1": "value1", +// "option2": "value2", +// }, +// }, +// } package lint import ( @@ -61,9 +70,9 @@ type Config struct { // If a linter is not in this map, it uses its default enabled state Enabled map[string]bool - // Settings maps linter names to their configuration - // The configuration type is linter-specific - Settings map[string]any + // Settings maps linter names to their configuration as map[string]string + // Each linter's settings are provided as key-value string pairs + Settings map[string]map[string]string } // RunLinters runs all enabled linters and returns any violations found. @@ -104,22 +113,35 @@ func RunLinters(createTables []*statement.CreateTable, alterStatements []*statem } // Apply configuration if available - if configurableLinter, ok := linter.impl.(ConfigurableLinter); ok { + if configurableLinter, ok := linter.l.(ConfigurableLinter); ok { + // Start with default config + defaultConfig := configurableLinter.DefaultConfig() + + // Merge user settings with defaults (user settings override defaults) + finalConfig := make(map[string]string) + for k, v := range defaultConfig { + finalConfig[k] = v + } + if settings, ok := config.Settings[name]; ok { - err := configurableLinter.Configure(settings) - if err != nil { - // Configuration error - skip this linter - // In a production system, we might want to log this - fmt.Fprintf(os.Stderr, "Error configuring %s: %s\n", name, err) - errs = append(errs, err) - - continue + for k, v := range settings { + finalConfig[k] = v } } + + // Apply the merged configuration + err := configurableLinter.Configure(finalConfig) + if err != nil { + // Configuration error - skip this linter + fmt.Fprintf(os.Stderr, "Error configuring %s: %s\n", name, err) + errs = append(errs, err) + + continue + } } // Run the linter - lintViolations := linter.impl.Lint(createTables, alterStatements) + lintViolations := linter.l.Lint(createTables, alterStatements) violations = append(violations, lintViolations...) } diff --git a/pkg/lint/lint_invisible_index.go b/pkg/lint/lint_invisible_index.go index fb2e8904..30f24074 100644 --- a/pkg/lint/lint_invisible_index.go +++ b/pkg/lint/lint_invisible_index.go @@ -1,9 +1,7 @@ package lint import ( - "errors" "fmt" - "strings" "github.com/block/spirit/pkg/statement" "github.com/pingcap/tidb/pkg/parser/ast" @@ -31,26 +29,16 @@ func (l *InvisibleIndexBeforeDropLinter) Description() string { return "Requires indexes to be made invisible before dropping them as a safety measure" } -func (l *InvisibleIndexBeforeDropLinter) Configure(a any) error { - c, ok := a.(map[string]string) - if !ok { - return errors.New(l.Name() + " config must be a map[string]string") - } - - for k, v := range c { +func (l *InvisibleIndexBeforeDropLinter) Configure(config map[string]string) error { + for k, v := range config { switch k { case "raiseError": - if strings.EqualFold(v, "true") { - l.raiseError = true - break + boolVal, err := ConfigBool(v, k) + if err != nil { + return err } - if strings.EqualFold(v, "false") { - l.raiseError = false - break - } - - return fmt.Errorf("invalid value for %s: %s", k, v) + l.raiseError = boolVal default: return fmt.Errorf("unknown config key for %s: %s", l.Name(), k) } @@ -58,9 +46,10 @@ func (l *InvisibleIndexBeforeDropLinter) Configure(a any) error { return nil } -func (l *InvisibleIndexBeforeDropLinter) DefaultConfig() any { + +func (l *InvisibleIndexBeforeDropLinter) DefaultConfig() map[string]string { return map[string]string{ - "raiseError": "true", + "raiseError": "false", } } diff --git a/pkg/lint/lint_invisible_index_test.go b/pkg/lint/lint_invisible_index_test.go index 9f0ca058..1606d0f3 100644 --- a/pkg/lint/lint_invisible_index_test.go +++ b/pkg/lint/lint_invisible_index_test.go @@ -186,3 +186,267 @@ func TestInvisibleIndexBeforeDropLinter_Metadata(t *testing.T) { assert.Equal(t, "invisible_index_before_drop", linter.Name()) assert.NotEmpty(t, linter.Description()) } + +// Configuration Tests + +func TestInvisibleIndexBeforeDropLinter_Configure_ValidRaiseErrorTrue(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + config := map[string]string{ + "raiseError": "true", + } + + err := linter.Configure(config) + require.NoError(t, err) + assert.True(t, linter.raiseError) +} + +func TestInvisibleIndexBeforeDropLinter_Configure_ValidRaiseErrorFalse(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + config := map[string]string{ + "raiseError": "false", + } + + err := linter.Configure(config) + require.NoError(t, err) + assert.False(t, linter.raiseError) +} + +func TestInvisibleIndexBeforeDropLinter_Configure_CaseInsensitive(t *testing.T) { + tests := []struct { + name string + value string + expected bool + }{ + {"lowercase true", "true", true}, + {"uppercase TRUE", "TRUE", true}, + {"mixed True", "True", true}, + {"lowercase false", "false", false}, + {"uppercase FALSE", "FALSE", false}, + {"mixed False", "False", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + config := map[string]string{ + "raiseError": tt.value, + } + + err := linter.Configure(config) + require.NoError(t, err) + assert.Equal(t, tt.expected, linter.raiseError) + }) + } +} + +func TestInvisibleIndexBeforeDropLinter_Configure_InvalidRaiseErrorValue(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + config := map[string]string{ + "raiseError": "invalid", + } + + err := linter.Configure(config) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid value for raiseError") +} + +func TestInvisibleIndexBeforeDropLinter_Configure_UnknownKey(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + config := map[string]string{ + "unknownKey": "value", + } + + err := linter.Configure(config) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown config key") + assert.Contains(t, err.Error(), "unknownKey") +} + +func TestInvisibleIndexBeforeDropLinter_Configure_MultipleKeys(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + config := map[string]string{ + "raiseError": "false", + "unknownKey": "value", + } + + err := linter.Configure(config) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown config key") +} + +func TestInvisibleIndexBeforeDropLinter_DefaultConfig(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + defaultConfig := linter.DefaultConfig() + require.NotNil(t, defaultConfig) + assert.Equal(t, "false", defaultConfig["raiseError"]) +} + +// Functional Tests with Configuration + +func TestInvisibleIndexBeforeDropLinter_RaiseErrorTrue_ProducesError(t *testing.T) { + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + err = linter.Configure(map[string]string{"raiseError": "true"}) + require.NoError(t, err) + + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + assert.Equal(t, SeverityError, violations[0].Severity) +} + +func TestInvisibleIndexBeforeDropLinter_RaiseErrorFalse_ProducesWarning(t *testing.T) { + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + err = linter.Configure(map[string]string{"raiseError": "false"}) + require.NoError(t, err) + + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + assert.Equal(t, SeverityWarning, violations[0].Severity) +} + +func TestInvisibleIndexBeforeDropLinter_DefaultBehavior(t *testing.T) { + // Without configuration, default should be raiseError=true (warning) + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + // Default behavior is warning (raiseError defaults to false in struct) + assert.Equal(t, SeverityWarning, violations[0].Severity) +} + +// Integration Tests with RunLinters + +func TestInvisibleIndexBeforeDropLinter_IntegrationWithConfig_RaiseErrorTrue(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + violations, err := RunLinters(nil, stmts, Config{ + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "true", + }, + }, + }) + require.NoError(t, err) + + require.Len(t, violations, 1) + assert.Equal(t, SeverityError, violations[0].Severity) +} + +func TestInvisibleIndexBeforeDropLinter_IntegrationWithConfig_RaiseErrorFalse(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + violations, err := RunLinters(nil, stmts, Config{ + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "false", + }, + }, + }) + require.NoError(t, err) + + require.Len(t, violations, 1) + assert.Equal(t, SeverityWarning, violations[0].Severity) +} + +func TestInvisibleIndexBeforeDropLinter_IntegrationWithConfig_InvalidConfig(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + // Invalid configuration should result in error + violations, err := RunLinters(nil, stmts, Config{ + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "invalid_value", + }, + }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid value for raiseError") + // Linter should be skipped due to configuration error + assert.Empty(t, violations) +} + +func TestInvisibleIndexBeforeDropLinter_IntegrationWithConfig_DisabledLinter(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + // Even with configuration, disabled linter should not run + violations, err := RunLinters(nil, stmts, Config{ + Enabled: map[string]bool{ + "invisible_index_before_drop": false, + }, + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "true", + }, + }, + }) + require.NoError(t, err) + assert.Empty(t, violations) +} + +func TestInvisibleIndexBeforeDropLinter_IntegrationWithConfig_EnabledWithConfig(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + // Explicitly enabled with custom configuration + violations, err := RunLinters(nil, stmts, Config{ + Enabled: map[string]bool{ + "invisible_index_before_drop": true, + }, + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "false", + }, + }, + }) + require.NoError(t, err) + + require.Len(t, violations, 1) + assert.Equal(t, SeverityWarning, violations[0].Severity) +} diff --git a/pkg/lint/lint_primary_key_type.go b/pkg/lint/lint_primary_key_type.go index b0e81d52..8c38300c 100644 --- a/pkg/lint/lint_primary_key_type.go +++ b/pkg/lint/lint_primary_key_type.go @@ -143,9 +143,5 @@ func (l *PrimaryKeyTypeLinter) isBinaryType(column *statement.Column) bool { // VARBINARY is mysql.TypeVarchar with binary flag rawType := column.Raw.Tp.GetType() - // Check if it's a string type with binary flag (BINARY/VARBINARY) - - fmt.Printf("Debug: type=%s rawType=%d, flags=%d, options=%#v\n", column.Type, rawType, column.Raw.Tp.GetFlag(), column.Options) - return (rawType == mysql.TypeString || rawType == mysql.TypeVarchar) && mysql.HasBinaryFlag(column.Raw.Tp.GetFlag()) } diff --git a/pkg/lint/lint_test.go b/pkg/lint/lint_test.go index 60801eed..04c5f86f 100644 --- a/pkg/lint/lint_test.go +++ b/pkg/lint/lint_test.go @@ -31,7 +31,7 @@ type mockConfigurableLinter struct { mockLinter configCalled bool - configValue any + configValue map[string]string } func (m *mockConfigurableLinter) String() string { @@ -39,15 +39,17 @@ func (m *mockConfigurableLinter) String() string { panic("implement me") } -func (m *mockConfigurableLinter) Configure(config any) error { +func (m *mockConfigurableLinter) Configure(config map[string]string) error { m.configCalled = true m.configValue = config return nil } -func (m *mockConfigurableLinter) DefaultConfig() any { - return "default" +func (m *mockConfigurableLinter) DefaultConfig() map[string]string { + return map[string]string{ + "default": "value", + } } func TestRegister(t *testing.T) { @@ -264,7 +266,7 @@ func TestRunLinters_ConfigurableLinter(t *testing.T) { config := map[string]string{"key": "value"} violations, err := RunLinters(nil, nil, Config{ - Settings: map[string]any{ + Settings: map[string]map[string]string{ "configurable_linter": config, }, }) @@ -272,7 +274,12 @@ func TestRunLinters_ConfigurableLinter(t *testing.T) { assert.Len(t, violations, 1) assert.True(t, linter.configCalled) - assert.Equal(t, config, linter.configValue) + // User config should be merged with defaults + expected := map[string]string{ + "default": "value", // from DefaultConfig + "key": "value", // from user config + } + assert.Equal(t, expected, linter.configValue) } func TestRunLinters_ConfigurableLinter_NoConfig(t *testing.T) { @@ -289,7 +296,10 @@ func TestRunLinters_ConfigurableLinter_NoConfig(t *testing.T) { require.NoError(t, err) assert.Len(t, violations, 1) - assert.False(t, linter.configCalled) + // Now Configure is always called (with defaults) + assert.True(t, linter.configCalled) + // Should have received the default config + assert.Equal(t, map[string]string{"default": "value"}, linter.configValue) } func TestHasErrors(t *testing.T) { @@ -440,3 +450,96 @@ func TestViolationWithContext(t *testing.T) { assert.Equal(t, "value1", violation.Context["key1"]) assert.Equal(t, 42, violation.Context["key2"]) } + +// ConfigBool tests + +func TestConfigBool_ValidTrue(t *testing.T) { + tests := []string{"true", "TRUE", "True", "TrUe"} + for _, value := range tests { + t.Run(value, func(t *testing.T) { + result, err := ConfigBool(value, "testKey") + require.NoError(t, err) + assert.True(t, result) + }) + } +} + +func TestConfigBool_ValidFalse(t *testing.T) { + tests := []string{"false", "FALSE", "False", "FaLsE"} + for _, value := range tests { + t.Run(value, func(t *testing.T) { + result, err := ConfigBool(value, "testKey") + require.NoError(t, err) + assert.False(t, result) + }) + } +} + +func TestConfigBool_Invalid(t *testing.T) { + tests := []struct { + value string + key string + }{ + {"yes", "testKey"}, + {"no", "testKey"}, + {"1", "testKey"}, + {"0", "testKey"}, + {"True ", "testKey"}, // trailing space + {" true", "testKey"}, // leading space + {"", "testKey"}, + {"invalid", "myOption"}, + } + + for _, tt := range tests { + t.Run(tt.value, func(t *testing.T) { + result, err := ConfigBool(tt.value, tt.key) + require.Error(t, err) + assert.False(t, result) + assert.Contains(t, err.Error(), "invalid value for "+tt.key) + assert.Contains(t, err.Error(), tt.value) + assert.Contains(t, err.Error(), "expected 'true' or 'false'") + }) + } +} + +// DefaultConfig tests + +func TestRunLinters_AppliesDefaultConfig(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + // Run without any config - should apply default (raiseError=false) + violations, err := RunLinters(nil, stmts, Config{}) + require.NoError(t, err) + + require.Len(t, violations, 1) + // Default raiseError is "false", so severity should be Warning + assert.Equal(t, SeverityWarning, violations[0].Severity) +} + +func TestRunLinters_UserConfigOverridesDefault(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + // Override default raiseError=false with true + violations, err := RunLinters(nil, stmts, Config{ + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "true", + }, + }, + }) + require.NoError(t, err) + + require.Len(t, violations, 1) + // User set raiseError=true, so severity should be Error + assert.Equal(t, SeverityError, violations[0].Severity) +} diff --git a/pkg/lint/linter.go b/pkg/lint/linter.go index e780928c..f24360a0 100644 --- a/pkg/lint/linter.go +++ b/pkg/lint/linter.go @@ -1,6 +1,9 @@ package lint import ( + "fmt" + "strings" + "github.com/block/spirit/pkg/statement" ) @@ -25,10 +28,11 @@ type ConfigurableLinter interface { Linter // Configure applies configuration to the linter - Configure(config any) error + // Configuration is provided as a map of string keys to string values + Configure(config map[string]string) error // DefaultConfig returns the default configuration for this linter - DefaultConfig() any + DefaultConfig() map[string]string } // Stringer returns a string representation of the linter @@ -36,3 +40,18 @@ type ConfigurableLinter interface { func Stringer(l Linter) string { return l.Name() + " - " + l.Description() } + +// ConfigBool parses a boolean configuration value from a string. +// It accepts "true" or "false" (case-insensitive) and returns an error for invalid values. +// The key parameter is used in error messages to provide context. +func ConfigBool(value string, key string) (bool, error) { + if strings.EqualFold(value, "true") { + return true, nil + } + + if strings.EqualFold(value, "false") { + return false, nil + } + + return false, fmt.Errorf("invalid value for %s: %s (expected 'true' or 'false')", key, value) +} diff --git a/pkg/lint/registry.go b/pkg/lint/registry.go index d6a5d1c5..dffe8632 100644 --- a/pkg/lint/registry.go +++ b/pkg/lint/registry.go @@ -8,7 +8,7 @@ import ( // linter represents a registered linter with metadata type linter struct { - impl Linter + l Linter enabled bool } @@ -29,7 +29,7 @@ func Register(l Linter) { } linters[l.Name()] = linter{ - impl: l, + l: l, enabled: true, } } @@ -98,7 +98,7 @@ func Get(name string) (Linter, error) { return nil, fmt.Errorf("linter %q not found", name) } - return l.impl, nil + return l.l, nil } // Reset clears all registered linters. From c59d6356fc3c108cc6c4c3146f36cc6cba1876f0 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Mon, 20 Oct 2025 10:28:54 -0600 Subject: [PATCH 09/12] add first iteration of command-line lint tool # Conflicts: # pkg/lint/lint.go --- cmd/lint/lint.go | 15 +++++++++++++++ pkg/statement/parse_create_table.go | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 cmd/lint/lint.go diff --git a/cmd/lint/lint.go b/cmd/lint/lint.go new file mode 100644 index 00000000..5484d02e --- /dev/null +++ b/cmd/lint/lint.go @@ -0,0 +1,15 @@ +package main + +import ( + "github.com/alecthomas/kong" + "github.com/block/spirit/pkg/lint" +) + +var cli struct { + lint.Lint `cmd:"" help:"Lint CREATE TABLE and ALTER TABLE statements."` +} + +func main() { + ctx := kong.Parse(&cli) + ctx.FatalIfErrorf(ctx.Run()) +} diff --git a/pkg/statement/parse_create_table.go b/pkg/statement/parse_create_table.go index 83e6bcaf..2029cf46 100644 --- a/pkg/statement/parse_create_table.go +++ b/pkg/statement/parse_create_table.go @@ -171,7 +171,7 @@ func ParseCreateTable(sql string) (*CreateTable, error) { stmts, _, err := p.Parse(sql, "", "") if err != nil { - return nil, fmt.Errorf("failed to parse SQL: %w", err) + return nil, fmt.Errorf("failed to parse SQL %q: %w", sql, err) } if len(stmts) != 1 { From a1ff40dc58cdb1260a28af7233701cc4149d4506 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Mon, 20 Oct 2025 12:47:23 -0600 Subject: [PATCH 10/12] iterating lint command --- pkg/lint/cmd.go | 256 +++++++++++++++++++++++++++ pkg/lint/cmd_test.go | 404 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 660 insertions(+) create mode 100644 pkg/lint/cmd.go create mode 100644 pkg/lint/cmd_test.go diff --git a/pkg/lint/cmd.go b/pkg/lint/cmd.go new file mode 100644 index 00000000..bad578c9 --- /dev/null +++ b/pkg/lint/cmd.go @@ -0,0 +1,256 @@ +package lint + +import ( + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/block/spirit/pkg/statement" +) + +// StatementSource represents a single source of SQL statements. +type StatementSource struct { + // Origin describes where this SQL came from + // For files: "file:" + file path (e.g., "file:migrations/001.sql") + // For stdin: "stdin" + // For command-line: "cmdline" + Origin string + + // SQL contains the actual SQL content + SQL string +} + +// resolveStatement takes a single --statement argument and returns one or more StatementSources. +// - Inline SQL → 1 StatementSource with Origin="cmdline" +// - "-" (stdin) → 1 StatementSource with Origin="stdin" +// - "file:path.sql" → 1 StatementSource with Origin="file:path.sql" +// - "file:dir/" → N StatementSources (one per .sql file in directory, recursively) +// - "file:*.sql" → N StatementSources (one per matching file) +func resolveStatement(arg string) ([]StatementSource, error) { + // Check for stdin + if arg == "-" { + content, err := io.ReadAll(os.Stdin) + if err != nil { + return nil, fmt.Errorf("failed to read from stdin: %w", err) + } + return []StatementSource{{ + Origin: "stdin", + SQL: string(content), + }}, nil + } + + // Check for file: prefix + if strings.HasPrefix(arg, "file:") { + path := strings.TrimPrefix(arg, "file:") + + // Check if it's a glob pattern (contains wildcard characters) + if strings.ContainsAny(path, "*?[]") { + return resolveGlob(path) + } + + // Try to stat the path to determine if it's a file or directory + info, err := os.Stat(path) + if err != nil { + return nil, fmt.Errorf("failed to access %s: %w", path, err) + } + + if info.IsDir() { + return resolveDirectory(path) + } + + return resolveFile(path) + } + + // Default to command-line SQL + return []StatementSource{{ + Origin: "cmdline", + SQL: arg, + }}, nil +} + +// resolveFile reads a single SQL file and returns a StatementSource +func resolveFile(path string) ([]StatementSource, error) { + content, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read file %s: %w", path, err) + } + + return []StatementSource{{ + Origin: "file:" + path, + SQL: string(content), + }}, nil +} + +// resolveDirectory recursively finds all .sql files in a directory and returns StatementSources +func resolveDirectory(dir string) ([]StatementSource, error) { + var sources []StatementSource + + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // Skip directories and non-.sql files + if info.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".sql") { + return nil + } + + content, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("failed to read file %s: %w", path, err) + } + + sources = append(sources, StatementSource{ + Origin: "file:" + path, + SQL: string(content), + }) + + return nil + }) + + if err != nil { + return nil, err + } + + if len(sources) == 0 { + return nil, fmt.Errorf("no .sql files found in directory: %s", dir) + } + + return sources, nil +} + +// resolveGlob expands a glob pattern and returns StatementSources for all matching files +func resolveGlob(pattern string) ([]StatementSource, error) { + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, fmt.Errorf("invalid glob pattern %s: %w", pattern, err) + } + + if len(matches) == 0 { + return nil, fmt.Errorf("no files matched glob pattern: %s", pattern) + } + + var sources []StatementSource + + for _, path := range matches { + // Skip directories + info, err := os.Stat(path) + if err != nil { + return nil, fmt.Errorf("failed to stat file %s: %w", path, err) + } + if info.IsDir() { + continue + } + + content, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read file %s: %w", path, err) + } + + sources = append(sources, StatementSource{ + Origin: "file:" + path, + SQL: string(content), + }) + } + + if len(sources) == 0 { + return nil, fmt.Errorf("glob pattern matched only directories: %s", pattern) + } + + return sources, nil +} + +// parseStatementSource parses a single StatementSource and extracts CREATE TABLE and ALTER TABLE statements. +// Returns the parsed statements and any error encountered. +// Note: Due to limitations in statement.New(), a single source cannot contain both CREATE TABLE and ALTER TABLE statements. +func parseStatementSource(source StatementSource) ([]*statement.CreateTable, []*statement.AbstractStatement, error) { + sql := strings.TrimSpace(source.SQL) + if sql == "" { + return nil, nil, nil // Empty source is OK + } + + var createTables []*statement.CreateTable + var alterStatements []*statement.AbstractStatement + + // Parse all statements + stmts, err := statement.New(sql) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse %s: %w", source.Origin, err) + } + + // Categorize statements + for _, stmt := range stmts { + if stmt.IsAlterTable() { + alterStatements = append(alterStatements, stmt) + } else { + // It's a CREATE TABLE, parse into structured format + ct, err := statement.ParseCreateTable(stmt.Statement) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse CREATE TABLE from %s: %w", source.Origin, err) + } + createTables = append(createTables, ct) + } + } + + return createTables, alterStatements, nil +} + +// Lint is the struct for the lint command +type Lint struct { + Statement []string `help:"CREATE TABLE and ALTER TABLE statements to lint" sep:"none"` + Linters []string `help:"Specific linters to run (default: all)" default:"all"` + Config []string `help:"Individual linter configuration properties"` +} + +func (l *Lint) Run() error { + var lintConfig Config + foundViolations := false + + if len(l.Statement) == 0 { + return errors.New("must specify at least one statement to lint") + } + + // Resolve all statement arguments into sources + var sources []StatementSource + for _, arg := range l.Statement { + s, err := resolveStatement(arg) + if err != nil { + return err + } + sources = append(sources, s...) + } + + // Parse each source + for _, source := range sources { + createTables, alterStatements, err := parseStatementSource(source) + if err != nil { + return err + } + if len(createTables) == 0 && len(alterStatements) == 0 { + fmt.Fprintf(os.Stderr, "Warning: no valid statements found in %s, skipping\n", source.Origin) + continue // No valid statements in this source + } + + // Run linters + violations := RunLinters(createTables, alterStatements, lintConfig) + + if len(violations) == 0 { + fmt.Printf("No lint violations found in %q\n", source.Origin) + continue + } + fmt.Printf("Found lint violations found in %q:\n", source.Origin) + foundViolations = true + + for _, v := range violations { + fmt.Println(v.String()) + } + } + + if foundViolations { + return errors.New("lint violations found") + } + return nil +} diff --git a/pkg/lint/cmd_test.go b/pkg/lint/cmd_test.go new file mode 100644 index 00000000..9fa441e9 --- /dev/null +++ b/pkg/lint/cmd_test.go @@ -0,0 +1,404 @@ +package lint + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResolveStatement_Cmdline(t *testing.T) { + tests := []struct { + name string + arg string + }{ + { + name: "CREATE TABLE", + arg: "CREATE TABLE users (id BIGINT PRIMARY KEY)", + }, + { + name: "ALTER TABLE", + arg: "ALTER TABLE users ADD COLUMN email VARCHAR(255)", + }, + { + name: "multiline SQL", + arg: "CREATE TABLE users (\n id BIGINT PRIMARY KEY\n)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sources, err := resolveStatement(tt.arg) + require.NoError(t, err) + require.Len(t, sources, 1) + assert.Equal(t, "cmdline", sources[0].Origin) + assert.Equal(t, tt.arg, sources[0].SQL) + }) + } +} + +func TestResolveStatement_Stdin(t *testing.T) { + // Mock stdin + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + + r, w, _ := os.Pipe() + os.Stdin = r + + sql := "CREATE TABLE users (id BIGINT PRIMARY KEY)" + go func() { + w.Write([]byte(sql)) + w.Close() + }() + + sources, err := resolveStatement("-") + require.NoError(t, err) + require.Len(t, sources, 1) + assert.Equal(t, "stdin", sources[0].Origin) + assert.Equal(t, sql, sources[0].SQL) +} + +func TestResolveStatement_File(t *testing.T) { + // Create a temporary file + tmpfile, err := os.CreateTemp("", "test_*.sql") + require.NoError(t, err) + defer os.Remove(tmpfile.Name()) + + sql := "CREATE TABLE users (id BIGINT PRIMARY KEY)" + _, err = tmpfile.Write([]byte(sql)) + require.NoError(t, err) + tmpfile.Close() + + sources, err := resolveStatement("file:" + tmpfile.Name()) + require.NoError(t, err) + require.Len(t, sources, 1) + assert.Equal(t, "file:"+tmpfile.Name(), sources[0].Origin) + assert.Equal(t, sql, sources[0].SQL) +} + +func TestResolveStatement_FileNotExists(t *testing.T) { + sources, err := resolveStatement("file:/nonexistent/path/to/file.sql") + assert.Error(t, err) + assert.Nil(t, sources) + assert.Contains(t, err.Error(), "failed to access") +} + +func TestResolveStatement_Directory(t *testing.T) { + // Create a temporary directory with SQL files + tmpdir, err := os.MkdirTemp("", "test_migrations_") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + // Create some files + sql1 := "CREATE TABLE users (id INT)" + sql2 := "CREATE TABLE orders (id INT)" + os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte(sql1), 0644) + os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte(sql2), 0644) + os.WriteFile(filepath.Join(tmpdir, "README.md"), []byte("# Migrations"), 0644) + + // Create a subdirectory with a file + subdir := filepath.Join(tmpdir, "archived") + os.Mkdir(subdir, 0755) + sql3 := "CREATE TABLE old (id INT)" + os.WriteFile(filepath.Join(subdir, "old.sql"), []byte(sql3), 0644) + + sources, err := resolveStatement("file:" + tmpdir) + require.NoError(t, err) + require.Len(t, sources, 3) // Should find all 3 .sql files recursively + + // Verify origins have file: prefix + for _, source := range sources { + assert.True(t, strings.HasPrefix(source.Origin, "file:")) + assert.True(t, strings.HasSuffix(source.Origin, ".sql")) + } + + // Verify we got the right content + sqlContents := []string{sources[0].SQL, sources[1].SQL, sources[2].SQL} + assert.Contains(t, sqlContents, sql1) + assert.Contains(t, sqlContents, sql2) + assert.Contains(t, sqlContents, sql3) +} + +func TestResolveStatement_DirectoryEmpty(t *testing.T) { + tmpdir, err := os.MkdirTemp("", "test_empty_") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + sources, err := resolveStatement("file:" + tmpdir) + assert.Error(t, err) + assert.Nil(t, sources) + assert.Contains(t, err.Error(), "no .sql files found") +} + +func TestResolveStatement_Glob(t *testing.T) { + // Create a temporary directory with SQL files + tmpdir, err := os.MkdirTemp("", "test_migrations_") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + // Create some files + sql1 := "CREATE TABLE users (id INT)" + sql2 := "CREATE TABLE orders (id INT)" + sql3 := "CREATE TABLE products (id INT)" + os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte(sql1), 0644) + os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte(sql2), 0644) + os.WriteFile(filepath.Join(tmpdir, "003_products.sql"), []byte(sql3), 0644) + os.WriteFile(filepath.Join(tmpdir, "README.md"), []byte("# Migrations"), 0644) + + // Test glob pattern + pattern := "file:" + filepath.Join(tmpdir, "*.sql") + sources, err := resolveStatement(pattern) + require.NoError(t, err) + require.Len(t, sources, 3) // Should find all 3 .sql files + + // Verify origins + for _, source := range sources { + assert.True(t, strings.HasPrefix(source.Origin, "file:")) + assert.True(t, strings.HasSuffix(source.Origin, ".sql")) + } +} + +func TestResolveStatement_GlobNoMatches(t *testing.T) { + tmpdir, err := os.MkdirTemp("", "test_migrations_") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + pattern := "file:" + filepath.Join(tmpdir, "*.sql") + sources, err := resolveStatement(pattern) + assert.Error(t, err) + assert.Nil(t, sources) + assert.Contains(t, err.Error(), "no files matched glob pattern") +} + +func TestResolveStatement_GlobWithPattern(t *testing.T) { + tmpdir, err := os.MkdirTemp("", "test_migrations_") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + // Create files with different patterns + os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte("CREATE TABLE users (id INT)"), 0644) + os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte("CREATE TABLE orders (id INT)"), 0644) + os.WriteFile(filepath.Join(tmpdir, "999_old.sql"), []byte("CREATE TABLE old (id INT)"), 0644) + + // Test pattern that matches only 001 and 002 + pattern := "file:" + filepath.Join(tmpdir, "00[12]*.sql") + sources, err := resolveStatement(pattern) + require.NoError(t, err) + require.Len(t, sources, 2) +} + +func TestResolveStatement_GlobSkipsDirectories(t *testing.T) { + tmpdir, err := os.MkdirTemp("", "test_migrations_") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + // Create a subdirectory that would match the glob + subdir := filepath.Join(tmpdir, "migrations") + os.Mkdir(subdir, 0755) + + // Create a file + os.WriteFile(filepath.Join(tmpdir, "001.sql"), []byte("CREATE TABLE users (id INT)"), 0644) + + // Glob should skip the directory + pattern := "file:" + filepath.Join(tmpdir, "*") + sources, err := resolveStatement(pattern) + require.NoError(t, err) + require.Len(t, sources, 1) // Only the file, not the directory +} + +func TestResolveStatement_Integration(t *testing.T) { + // Create a realistic test directory structure + tmpdir, err := os.MkdirTemp("", "test_migrations_") + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + + // Create some files + os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte("CREATE TABLE users (id INT)"), 0644) + os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte("CREATE TABLE orders (id INT)"), 0644) + + tests := []struct { + name string + arg string + expectedCount int + }{ + { + name: "specific file", + arg: "file:" + filepath.Join(tmpdir, "001_users.sql"), + expectedCount: 1, + }, + { + name: "directory", + arg: "file:" + tmpdir, + expectedCount: 2, + }, + { + name: "glob all sql", + arg: "file:" + filepath.Join(tmpdir, "*.sql"), + expectedCount: 2, + }, + { + name: "glob with pattern", + arg: "file:" + filepath.Join(tmpdir, "001*.sql"), + expectedCount: 1, + }, + { + name: "cmdline", + arg: "CREATE TABLE test (id INT)", + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sources, err := resolveStatement(tt.arg) + require.NoError(t, err) + assert.Len(t, sources, tt.expectedCount) + }) + } +} + +func TestParseStatementSource_Empty(t *testing.T) { + source := StatementSource{ + Origin: "cmdline", + SQL: "", + } + + createTables, alterStatements, err := parseStatementSource(source) + require.NoError(t, err) + assert.Nil(t, createTables) + assert.Nil(t, alterStatements) +} + +func TestParseStatementSource_WhitespaceOnly(t *testing.T) { + source := StatementSource{ + Origin: "cmdline", + SQL: " \n\t ", + } + + createTables, alterStatements, err := parseStatementSource(source) + require.NoError(t, err) + assert.Nil(t, createTables) + assert.Nil(t, alterStatements) +} + +func TestParseStatementSource_SingleCreateTable(t *testing.T) { + source := StatementSource{ + Origin: "cmdline", + SQL: "CREATE TABLE users (id BIGINT UNSIGNED PRIMARY KEY, email VARCHAR(255))", + } + + createTables, alterStatements, err := parseStatementSource(source) + require.NoError(t, err) + require.Len(t, createTables, 1) + assert.Len(t, alterStatements, 0) + assert.Equal(t, "users", createTables[0].GetTableName()) +} + +func TestParseStatementSource_SingleAlterTable(t *testing.T) { + source := StatementSource{ + Origin: "cmdline", + SQL: "ALTER TABLE users ADD COLUMN email VARCHAR(255)", + } + + createTables, alterStatements, err := parseStatementSource(source) + require.NoError(t, err) + assert.Len(t, createTables, 0) + require.Len(t, alterStatements, 1) + assert.Equal(t, "users", alterStatements[0].Table) +} + +func TestParseStatementSource_MultipleAlterStatements(t *testing.T) { + source := StatementSource{ + Origin: "file:migrations/001.sql", + SQL: ` + ALTER TABLE users ADD COLUMN email VARCHAR(255); + ALTER TABLE users ADD INDEX idx_email (email); + `, + } + + createTables, alterStatements, err := parseStatementSource(source) + require.NoError(t, err) + assert.Len(t, createTables, 0) + require.Len(t, alterStatements, 2) + assert.Equal(t, "users", alterStatements[0].Table) + assert.Equal(t, "users", alterStatements[1].Table) +} + +func TestParseStatementSource_MixedStatements(t *testing.T) { + source := StatementSource{ + Origin: "file:schema.sql", + SQL: ` + CREATE TABLE users (id BIGINT UNSIGNED PRIMARY KEY); + ALTER TABLE users ADD INDEX idx_email (email); + `, + } + + // Mixed statements should fail due to statement.New() limitation + createTables, alterStatements, err := parseStatementSource(source) + assert.Error(t, err) + assert.Nil(t, createTables) + assert.Nil(t, alterStatements) + assert.Contains(t, err.Error(), "failed to parse file:schema.sql") +} + +func TestParseStatementSource_MultipleCreateStatements(t *testing.T) { + source := StatementSource{ + Origin: "file:schema.sql", + SQL: ` + CREATE TABLE users (id BIGINT UNSIGNED PRIMARY KEY); + CREATE TABLE orders (id BIGINT UNSIGNED PRIMARY KEY); + `, + } + + // Multiple CREATE statements should fail due to statement.New() limitation + createTables, alterStatements, err := parseStatementSource(source) + assert.Error(t, err) + assert.Nil(t, createTables) + assert.Nil(t, alterStatements) + assert.Contains(t, err.Error(), "failed to parse file:schema.sql") +} + +func TestParseStatementSource_InvalidSQL(t *testing.T) { + source := StatementSource{ + Origin: "cmdline", + SQL: "INVALID SQL STATEMENT", + } + + createTables, alterStatements, err := parseStatementSource(source) + assert.Error(t, err) + assert.Nil(t, createTables) + assert.Nil(t, alterStatements) + assert.Contains(t, err.Error(), "failed to parse cmdline") +} + +func TestParseStatementSource_ErrorContext(t *testing.T) { + source := StatementSource{ + Origin: "file:migrations/bad.sql", + SQL: "CREATE TABLE", + } + + _, _, err := parseStatementSource(source) + assert.Error(t, err) + assert.Contains(t, err.Error(), "file:migrations/bad.sql") +} + +func TestParseStatementSource_WithComments(t *testing.T) { + source := StatementSource{ + Origin: "file:schema.sql", + SQL: ` + -- This is a comment + ALTER TABLE users ADD COLUMN email VARCHAR(255); + /* Multi-line + comment */ + ALTER TABLE users ADD INDEX idx_email (email); + `, + } + + createTables, alterStatements, err := parseStatementSource(source) + require.NoError(t, err) + assert.Len(t, createTables, 0) + require.Len(t, alterStatements, 2) +} From f56a310e6218a1588fdfc1e0508d96ca60ac1ff3 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Mon, 20 Oct 2025 18:52:52 -0700 Subject: [PATCH 11/12] analyze all provided create & alter statements together. add README.md for cli. --- cmd/lint/README.md | 20 ++++++++++++++++++++ pkg/lint/cmd.go | 29 +++++++++++++---------------- 2 files changed, 33 insertions(+), 16 deletions(-) create mode 100644 cmd/lint/README.md diff --git a/cmd/lint/README.md b/cmd/lint/README.md new file mode 100644 index 00000000..3b56aebf --- /dev/null +++ b/cmd/lint/README.md @@ -0,0 +1,20 @@ +The `lint` command is a simplistic, **experimental** interface to spirit's `lint` package. + +It is not intended that this command be incorporated into your workflow at this time. The interface and +output of this command is likely to change drastically, and it's possible it will be removed entirely. + +You can provide the statements via the `--statements` option in one of three ways: +1. In plaintext on the command line (`--statements="CREATE TABLE ..." --statements="ALTER TABLE ..."`) +2. Via a file or directory or glob pattern containing the statements (`--statements=file:/path/to/file.sql`) +3. Via standard input (`--statements=-`) + +You can combine these however you like. + +All statements provided are considered as a single group. + +Because of implementation details of the `lint` package and in order to reduce the complexity of this CLI, +the `CREATE TABLE` and `ALTER TABLE` statements have to be provided in a pretty specific way: + +* `CREATE TABLE` statements must be provided one-by-one. One per `--statements` argument, one per file, etc. +* `ALTER TABLE` statements can be provided one-by-one, or multiple `ALTER TABLE` statements can be provided in a single +`--statements` argument or file, separated by semicolons. \ No newline at end of file diff --git a/pkg/lint/cmd.go b/pkg/lint/cmd.go index bad578c9..86550bca 100644 --- a/pkg/lint/cmd.go +++ b/pkg/lint/cmd.go @@ -206,8 +206,9 @@ type Lint struct { } func (l *Lint) Run() error { + var allCreateTables []*statement.CreateTable + var allAlterStatements []*statement.AbstractStatement var lintConfig Config - foundViolations := false if len(l.Statement) == 0 { return errors.New("must specify at least one statement to lint") @@ -233,24 +234,20 @@ func (l *Lint) Run() error { fmt.Fprintf(os.Stderr, "Warning: no valid statements found in %s, skipping\n", source.Origin) continue // No valid statements in this source } + allCreateTables = append(allCreateTables, createTables...) + allAlterStatements = append(allAlterStatements, alterStatements...) + } - // Run linters - violations := RunLinters(createTables, alterStatements, lintConfig) - - if len(violations) == 0 { - fmt.Printf("No lint violations found in %q\n", source.Origin) - continue - } - fmt.Printf("Found lint violations found in %q:\n", source.Origin) - foundViolations = true + // Run linters + violations := RunLinters(allCreateTables, allAlterStatements, lintConfig) - for _, v := range violations { - fmt.Println(v.String()) - } + if len(violations) == 0 { + fmt.Println("No lint violations found") + return nil } - if foundViolations { - return errors.New("lint violations found") + for _, v := range violations { + fmt.Println(v.String()) } - return nil + return errors.New("lint violations found") } From 6be0aed198b74e7e50cffe8d1d4ec7282cc49646 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Tue, 21 Oct 2025 08:42:14 -0700 Subject: [PATCH 12/12] rebase on new push to kolbek-lint-framework --- pkg/lint/cmd.go | 28 ++++++++++++---- pkg/lint/cmd_test.go | 80 ++++++++++++++++++++------------------------ 2 files changed, 57 insertions(+), 51 deletions(-) diff --git a/pkg/lint/cmd.go b/pkg/lint/cmd.go index 86550bca..f45af415 100644 --- a/pkg/lint/cmd.go +++ b/pkg/lint/cmd.go @@ -36,6 +36,7 @@ func resolveStatement(arg string) ([]StatementSource, error) { if err != nil { return nil, fmt.Errorf("failed to read from stdin: %w", err) } + return []StatementSource{{ Origin: "stdin", SQL: string(content), @@ -110,7 +111,6 @@ func resolveDirectory(dir string) ([]StatementSource, error) { return nil }) - if err != nil { return nil, err } @@ -141,6 +141,7 @@ func resolveGlob(pattern string) ([]StatementSource, error) { if err != nil { return nil, fmt.Errorf("failed to stat file %s: %w", path, err) } + if info.IsDir() { continue } @@ -172,8 +173,10 @@ func parseStatementSource(source StatementSource) ([]*statement.CreateTable, []* return nil, nil, nil // Empty source is OK } - var createTables []*statement.CreateTable - var alterStatements []*statement.AbstractStatement + var ( + createTables []*statement.CreateTable + alterStatements []*statement.AbstractStatement + ) // Parse all statements stmts, err := statement.New(sql) @@ -191,6 +194,7 @@ func parseStatementSource(source StatementSource) ([]*statement.CreateTable, []* if err != nil { return nil, nil, fmt.Errorf("failed to parse CREATE TABLE from %s: %w", source.Origin, err) } + createTables = append(createTables, ct) } } @@ -206,9 +210,11 @@ type Lint struct { } func (l *Lint) Run() error { - var allCreateTables []*statement.CreateTable - var allAlterStatements []*statement.AbstractStatement - var lintConfig Config + var ( + allCreateTables []*statement.CreateTable + allAlterStatements []*statement.AbstractStatement + lintConfig Config + ) if len(l.Statement) == 0 { return errors.New("must specify at least one statement to lint") @@ -216,11 +222,13 @@ func (l *Lint) Run() error { // Resolve all statement arguments into sources var sources []StatementSource + for _, arg := range l.Statement { s, err := resolveStatement(arg) if err != nil { return err } + sources = append(sources, s...) } @@ -230,16 +238,21 @@ func (l *Lint) Run() error { if err != nil { return err } + if len(createTables) == 0 && len(alterStatements) == 0 { fmt.Fprintf(os.Stderr, "Warning: no valid statements found in %s, skipping\n", source.Origin) continue // No valid statements in this source } + allCreateTables = append(allCreateTables, createTables...) allAlterStatements = append(allAlterStatements, alterStatements...) } // Run linters - violations := RunLinters(allCreateTables, allAlterStatements, lintConfig) + violations, err := RunLinters(allCreateTables, allAlterStatements, lintConfig) + if err != nil { + return fmt.Errorf("failed to run linters: %w", err) + } if len(violations) == 0 { fmt.Println("No lint violations found") @@ -249,5 +262,6 @@ func (l *Lint) Run() error { for _, v := range violations { fmt.Println(v.String()) } + return errors.New("lint violations found") } diff --git a/pkg/lint/cmd_test.go b/pkg/lint/cmd_test.go index 9fa441e9..cabd20d1 100644 --- a/pkg/lint/cmd_test.go +++ b/pkg/lint/cmd_test.go @@ -43,14 +43,16 @@ func TestResolveStatement_Cmdline(t *testing.T) { func TestResolveStatement_Stdin(t *testing.T) { // Mock stdin oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() r, w, _ := os.Pipe() os.Stdin = r sql := "CREATE TABLE users (id BIGINT PRIMARY KEY)" + go func() { - w.Write([]byte(sql)) + _, _ = w.WriteString(sql) w.Close() }() @@ -63,12 +65,13 @@ func TestResolveStatement_Stdin(t *testing.T) { func TestResolveStatement_File(t *testing.T) { // Create a temporary file - tmpfile, err := os.CreateTemp("", "test_*.sql") + tmpfile, err := os.CreateTemp(t.TempDir(), "test_*.sql") require.NoError(t, err) + defer os.Remove(tmpfile.Name()) sql := "CREATE TABLE users (id BIGINT PRIMARY KEY)" - _, err = tmpfile.Write([]byte(sql)) + _, err = tmpfile.WriteString(sql) require.NoError(t, err) tmpfile.Close() @@ -88,22 +91,22 @@ func TestResolveStatement_FileNotExists(t *testing.T) { func TestResolveStatement_Directory(t *testing.T) { // Create a temporary directory with SQL files - tmpdir, err := os.MkdirTemp("", "test_migrations_") - require.NoError(t, err) - defer os.RemoveAll(tmpdir) + tmpdir := t.TempDir() // Create some files sql1 := "CREATE TABLE users (id INT)" sql2 := "CREATE TABLE orders (id INT)" - os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte(sql1), 0644) - os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte(sql2), 0644) - os.WriteFile(filepath.Join(tmpdir, "README.md"), []byte("# Migrations"), 0644) + + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte(sql1), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte(sql2), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "README.md"), []byte("# Migrations"), 0644)) // Create a subdirectory with a file subdir := filepath.Join(tmpdir, "archived") - os.Mkdir(subdir, 0755) + require.NoError(t, os.Mkdir(subdir, 0755)) + sql3 := "CREATE TABLE old (id INT)" - os.WriteFile(filepath.Join(subdir, "old.sql"), []byte(sql3), 0644) + require.NoError(t, os.WriteFile(filepath.Join(subdir, "old.sql"), []byte(sql3), 0644)) sources, err := resolveStatement("file:" + tmpdir) require.NoError(t, err) @@ -123,9 +126,7 @@ func TestResolveStatement_Directory(t *testing.T) { } func TestResolveStatement_DirectoryEmpty(t *testing.T) { - tmpdir, err := os.MkdirTemp("", "test_empty_") - require.NoError(t, err) - defer os.RemoveAll(tmpdir) + tmpdir := t.TempDir() sources, err := resolveStatement("file:" + tmpdir) assert.Error(t, err) @@ -135,18 +136,17 @@ func TestResolveStatement_DirectoryEmpty(t *testing.T) { func TestResolveStatement_Glob(t *testing.T) { // Create a temporary directory with SQL files - tmpdir, err := os.MkdirTemp("", "test_migrations_") - require.NoError(t, err) - defer os.RemoveAll(tmpdir) + tmpdir := t.TempDir() // Create some files sql1 := "CREATE TABLE users (id INT)" sql2 := "CREATE TABLE orders (id INT)" sql3 := "CREATE TABLE products (id INT)" - os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte(sql1), 0644) - os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte(sql2), 0644) - os.WriteFile(filepath.Join(tmpdir, "003_products.sql"), []byte(sql3), 0644) - os.WriteFile(filepath.Join(tmpdir, "README.md"), []byte("# Migrations"), 0644) + + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte(sql1), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte(sql2), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "003_products.sql"), []byte(sql3), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "README.md"), []byte("# Migrations"), 0644)) // Test glob pattern pattern := "file:" + filepath.Join(tmpdir, "*.sql") @@ -162,9 +162,7 @@ func TestResolveStatement_Glob(t *testing.T) { } func TestResolveStatement_GlobNoMatches(t *testing.T) { - tmpdir, err := os.MkdirTemp("", "test_migrations_") - require.NoError(t, err) - defer os.RemoveAll(tmpdir) + tmpdir := t.TempDir() pattern := "file:" + filepath.Join(tmpdir, "*.sql") sources, err := resolveStatement(pattern) @@ -174,14 +172,12 @@ func TestResolveStatement_GlobNoMatches(t *testing.T) { } func TestResolveStatement_GlobWithPattern(t *testing.T) { - tmpdir, err := os.MkdirTemp("", "test_migrations_") - require.NoError(t, err) - defer os.RemoveAll(tmpdir) + tmpdir := t.TempDir() // Create files with different patterns - os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte("CREATE TABLE users (id INT)"), 0644) - os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte("CREATE TABLE orders (id INT)"), 0644) - os.WriteFile(filepath.Join(tmpdir, "999_old.sql"), []byte("CREATE TABLE old (id INT)"), 0644) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte("CREATE TABLE users (id INT)"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte("CREATE TABLE orders (id INT)"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "999_old.sql"), []byte("CREATE TABLE old (id INT)"), 0644)) // Test pattern that matches only 001 and 002 pattern := "file:" + filepath.Join(tmpdir, "00[12]*.sql") @@ -191,16 +187,14 @@ func TestResolveStatement_GlobWithPattern(t *testing.T) { } func TestResolveStatement_GlobSkipsDirectories(t *testing.T) { - tmpdir, err := os.MkdirTemp("", "test_migrations_") - require.NoError(t, err) - defer os.RemoveAll(tmpdir) + tmpdir := t.TempDir() // Create a subdirectory that would match the glob subdir := filepath.Join(tmpdir, "migrations") - os.Mkdir(subdir, 0755) + require.NoError(t, os.Mkdir(subdir, 0755)) // Create a file - os.WriteFile(filepath.Join(tmpdir, "001.sql"), []byte("CREATE TABLE users (id INT)"), 0644) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "001.sql"), []byte("CREATE TABLE users (id INT)"), 0644)) // Glob should skip the directory pattern := "file:" + filepath.Join(tmpdir, "*") @@ -211,13 +205,11 @@ func TestResolveStatement_GlobSkipsDirectories(t *testing.T) { func TestResolveStatement_Integration(t *testing.T) { // Create a realistic test directory structure - tmpdir, err := os.MkdirTemp("", "test_migrations_") - require.NoError(t, err) - defer os.RemoveAll(tmpdir) + tmpdir := t.TempDir() // Create some files - os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte("CREATE TABLE users (id INT)"), 0644) - os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte("CREATE TABLE orders (id INT)"), 0644) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte("CREATE TABLE users (id INT)"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte("CREATE TABLE orders (id INT)"), 0644)) tests := []struct { name string @@ -293,7 +285,7 @@ func TestParseStatementSource_SingleCreateTable(t *testing.T) { createTables, alterStatements, err := parseStatementSource(source) require.NoError(t, err) require.Len(t, createTables, 1) - assert.Len(t, alterStatements, 0) + assert.Empty(t, alterStatements) assert.Equal(t, "users", createTables[0].GetTableName()) } @@ -305,7 +297,7 @@ func TestParseStatementSource_SingleAlterTable(t *testing.T) { createTables, alterStatements, err := parseStatementSource(source) require.NoError(t, err) - assert.Len(t, createTables, 0) + assert.Empty(t, createTables) require.Len(t, alterStatements, 1) assert.Equal(t, "users", alterStatements[0].Table) } @@ -321,7 +313,7 @@ func TestParseStatementSource_MultipleAlterStatements(t *testing.T) { createTables, alterStatements, err := parseStatementSource(source) require.NoError(t, err) - assert.Len(t, createTables, 0) + assert.Empty(t, createTables) require.Len(t, alterStatements, 2) assert.Equal(t, "users", alterStatements[0].Table) assert.Equal(t, "users", alterStatements[1].Table) @@ -399,6 +391,6 @@ func TestParseStatementSource_WithComments(t *testing.T) { createTables, alterStatements, err := parseStatementSource(source) require.NoError(t, err) - assert.Len(t, createTables, 0) + assert.Empty(t, createTables) require.Len(t, alterStatements, 2) }