diff --git a/USAGE.md b/USAGE.md index d580aad7..56f46057 100644 --- a/USAGE.md +++ b/USAGE.md @@ -29,6 +29,8 @@ - [Experimental Features](#experimental-features) - [enable-experimental-multi-table-support](#enable-experimental-multi-table-support) - [enable-experimental-buffered-copy](#enable-experimental-buffered-copy) + - [`move` command](#move-command) + - [native linting support](#native-linting-support) ## Getting Started @@ -365,4 +367,15 @@ This feature provides a new top level binary `move`, which can copy whole schema This command depends strongly on the experimental buffered copy and multi-table support, both which are currently experimental. There is not too much which is special to move on top of these two features, so once they become stable, so too can `move`. -It is anticipated that `move` will need to provide some pluggable method of cutover so external metadata systems can be updated. There is no current design for this. \ No newline at end of file +It is anticipated that `move` will need to provide some pluggable method of cutover so external metadata systems can be updated. There is no current design for this. + + +### native linting support + +**Feature Description** + +This feature adds native linting support to Spirit, allowing for various rules to be applied to schema changes before they are executed. + +**Current Status** + +This feature is partially complete. It relies on new support for parsing CREATE TABLE statements (see `pkg/statetement/parse_create_table.go`). There are so far only a few linters implemented. This functionality is not currently exposed via command line flags. \ No newline at end of file diff --git a/cmd/lint/README.md b/cmd/lint/README.md new file mode 100644 index 00000000..3b56aebf --- /dev/null +++ b/cmd/lint/README.md @@ -0,0 +1,20 @@ +The `lint` command is a simplistic, **experimental** interface to spirit's `lint` package. + +It is not intended that this command be incorporated into your workflow at this time. The interface and +output of this command is likely to change drastically, and it's possible it will be removed entirely. + +You can provide the statements via the `--statements` option in one of three ways: +1. In plaintext on the command line (`--statements="CREATE TABLE ..." --statements="ALTER TABLE ..."`) +2. Via a file or directory or glob pattern containing the statements (`--statements=file:/path/to/file.sql`) +3. Via standard input (`--statements=-`) + +You can combine these however you like. + +All statements provided are considered as a single group. + +Because of implementation details of the `lint` package and in order to reduce the complexity of this CLI, +the `CREATE TABLE` and `ALTER TABLE` statements have to be provided in a pretty specific way: + +* `CREATE TABLE` statements must be provided one-by-one. One per `--statements` argument, one per file, etc. +* `ALTER TABLE` statements can be provided one-by-one, or multiple `ALTER TABLE` statements can be provided in a single +`--statements` argument or file, separated by semicolons. \ No newline at end of file diff --git a/cmd/lint/lint.go b/cmd/lint/lint.go new file mode 100644 index 00000000..5484d02e --- /dev/null +++ b/cmd/lint/lint.go @@ -0,0 +1,15 @@ +package main + +import ( + "github.com/alecthomas/kong" + "github.com/block/spirit/pkg/lint" +) + +var cli struct { + lint.Lint `cmd:"" help:"Lint CREATE TABLE and ALTER TABLE statements."` +} + +func main() { + ctx := kong.Parse(&cli) + ctx.FatalIfErrorf(ctx.Run()) +} diff --git a/pkg/lint/README.md b/pkg/lint/README.md new file mode 100644 index 00000000..1173ce9e --- /dev/null +++ b/pkg/lint/README.md @@ -0,0 +1,303 @@ +# Linter Framework + +The `lint` package provides a framework for static analysis of MySQL schema definitions and DDL statements. It enables validation and best-practice enforcement beyond the runtime checks provided by the `check` package. + +## Architecture + +All built-in linters are automatically registered and enabled when the `lint` package is imported. The framework uses a flat package structure: + +- **Core framework files**: `lint.go`, `linter.go`, `registry.go`, `violation.go` +- **Linter implementations**: `lint_*.go` (e.g., `lint_invisible_index.go`) + +## Quick Start + +### Using the Linter Framework + +```go +import ( + "github.com/block/spirit/pkg/lint" +) + +// All built-in linters are automatically registered! +violations, err := lint.RunLinters(tables, stmts, lint.Config{}) +if err != nil { + // Handle configuration errors +} + +// Check for errors +if lint.HasErrors(violations) { + // Handle errors +} + +// Filter violations +errors := lint.FilterBySeverity(violations, lint.SeverityError) +warnings := lint.FilterBySeverity(violations, lint.SeverityWarning) +``` + +### Creating a Custom Linter + +Custom linters can be +1. added directly to the `lint` package (in new files with the `lint_` prefix, for consistency) +2. added to your own package and registered by blank import that relies on the `init()` function +3. added to your own code and registered explicitly using `lint.Register()` + +```go +// lint_my_custom.go +package lint + +import ( + "github.com/block/spirit/pkg/statement" +) + +// Register your linter in init() +func init() { + Register(&MyCustomLinter{}) +} + +// MyCustomLinter checks custom rules +type MyCustomLinter struct{} + +func (l *MyCustomLinter) Name() string { return "my_custom" } +func (l *MyCustomLinter) Category() string { return "naming" } +func (l *MyCustomLinter) Description() string { return "Checks naming conventions" } +func (l *MyCustomLinter) String() string { return l.Name() } + +func (l *MyCustomLinter) Lint(createTables []*statement.CreateTable, alterStatements []*statement.AbstractStatement) []Violation { + var violations []Violation + + for _, ct := range createTables { + // Check table properties + if /* condition */ { + violations = append(violations, Violation{ + Linter: l, + Severity: SeverityWarning, + Message: "Table name issue", + Location: &Location{ + Table: ct.GetTableName(), + }, + }) + } + } + + return violations +} +``` + +### Configuring Linters + +#### Enabling/Disabling Linters + +```go +// Disable specific linters +violations, err := lint.RunLinters(tables, stmts, lint.Config{ + Enabled: map[string]bool{ + "invisible_index_before_drop": false, + "primary_key_type": true, + }, +}) +if err != nil { + // Handle configuration errors +} +``` + +#### Configurable Linters + +Some linters support additional configuration options via the `Settings` field. Linters that implement the `ConfigurableLinter` interface accept settings as `map[string]string`: + +```go +violations, err := lint.RunLinters(tables, stmts, lint.Config{ + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "true", // Make violations errors instead of warnings + }, + }, +}) +if err != nil { + // Handle configuration errors (e.g., invalid settings) +} +``` + +Each configurable linter defines its own settings keys and values. See the individual linter documentation below for available options. + +## Core Types + +### Severity Levels + +- **ERROR**: Will cause actual problems (data loss, inconsistency, MySQL limitations) +- **WARNING**: Best practice violations, potential issues +- **INFO**: Suggestions, style preferences + +### Violation + +```go +type Violation struct { + Linter Linter // The linter that produced this violation + Severity Severity // ERROR, WARNING, or INFO + Message string // Human-readable message + Location *Location // Where the violation occurred + Suggestion *string // Optional fix suggestion + Context map[string]any // Additional context +} +``` + +### Location + +```go +type Location struct { + Table string // Table name + Column *string // Column name (if applicable) + Index *string // Index name (if applicable) + Constraint *string // Constraint name (if applicable) +} +``` + +## API Functions + +### Registration + +- `Register(l Linter)` - Register a linter (call from init()) +- `Enable(name string)` - Enable a linter by name +- `Disable(name string)` - Disable a linter by name +- `List()` - Get all registered linter names +- `ListByCategory(category string)` - Get linters in a category +- `Get(name string)` - Get a linter by name + +### Execution + +- `RunLinters(createTables, alterStatements, config) ([]Violation, error)` - Run all enabled linters, returns violations and any configuration errors +- `HasErrors(violations)` - Check if any violations are errors +- `HasWarnings(violations)` - Check if any violations are warnings +- `FilterBySeverity(violations, severity)` - Filter by severity level +- `FilterByLinter(violations, name)` - Filter by linter name + +## Built-in Linters + +The `lint` package includes several linters: + +### invisible_index_before_drop + +**Category**: schema +**Severity**: Warning (default), Error (configurable) +**Configurable**: Yes + +Requires indexes to be made invisible before dropping them as a safety measure. This ensures the index isn't needed before permanently removing it. + +**Configuration Options:** + +- `raiseError` (string): Set to `"true"` to make violations errors instead of warnings. Default: `false`. + +**Example Usage:** + +```go +// ❌ Violation +ALTER TABLE users DROP INDEX idx_email; + +// ✅ Correct +ALTER TABLE users ALTER INDEX idx_email INVISIBLE; +-- Wait and monitor performance +ALTER TABLE users DROP INDEX idx_email; +``` + +**Configuration Example:** + +```go +violations, err := lint.RunLinters(tables, stmts, lint.Config{ + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "true", // Violations will be errors + }, + }, +}) +``` + +### multiple_alter_table + +**Category**: schema +**Severity**: Info + +Detects multiple ALTER TABLE statements on the same table that could be combined into one for better performance and fewer table rebuilds. + +```go +// ❌ Violation +ALTER TABLE users ADD COLUMN age INT; +ALTER TABLE users ADD INDEX idx_age (age); + +// ✅ Better +ALTER TABLE users + ADD COLUMN age INT, + ADD INDEX idx_age (age); +``` + +### primary_key_type + +**Category**: schema +**Severity**: Error (invalid types), Warning (signed BIGINT) + +Ensures primary keys use BIGINT (preferably UNSIGNED) or BINARY/VARBINARY types. + +```go +// ❌ Error - invalid type +CREATE TABLE users ( + id INT PRIMARY KEY -- Should be BIGINT +); + +// ⚠️ Warning - should be unsigned +CREATE TABLE users ( + id BIGINT PRIMARY KEY -- Should be BIGINT UNSIGNED +); + +// ✅ Correct +CREATE TABLE users ( + id BIGINT UNSIGNED PRIMARY KEY +); +``` + +## Example Linters + +The `example` package provides demonstration linters for learning purposes: + +### TableNameLengthLinter + +Checks that table names don't exceed MySQL's 64 character limit. + +### DuplicateColumnLinter + +Detects duplicate column definitions in CREATE TABLE statements. + +See `pkg/lint/example/` for reference implementations. + +## Contributing + +When adding new linters to the `lint` package: + +1. **Create a new file** with the `lint_` prefix (e.g., `lint_my_rule.go`) +2. **Implement the `Linter` interface** with all required methods +3. **Register in `init()`** function to enable automatic registration +4. **Add comprehensive tests** in a corresponding `lint_my_rule_test.go` file +5. **Document the linter** with clear comments and examples +6. **Choose appropriate severity levels**: + - `SeverityError` for violations that will cause actual problems + - `SeverityWarning` for best practice violations + - `SeverityInfo` for suggestions and style preferences +7. **Provide helpful messages** with actionable suggestions when possible +8. **Update this README** with documentation for the new linter + +### File Naming Convention + +- Linter implementation: `lint_.go` +- Linter tests: `lint__test.go` + +### Example Structure + +``` +pkg/lint/ +├── lint.go # Core API +├── linter.go # Interface definition +├── registry.go # Registration system +├── violation.go # Violation types +├── lint_invisible_index.go # Built-in linter +├── lint_multiple_alter.go # Built-in linter +├── lint_primary_key_type.go # Built-in linter +├── lint_my_new_rule.go # Your new linter +└── lint_my_new_rule_test.go # Your tests +``` diff --git a/pkg/lint/cmd.go b/pkg/lint/cmd.go new file mode 100644 index 00000000..f45af415 --- /dev/null +++ b/pkg/lint/cmd.go @@ -0,0 +1,267 @@ +package lint + +import ( + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/block/spirit/pkg/statement" +) + +// StatementSource represents a single source of SQL statements. +type StatementSource struct { + // Origin describes where this SQL came from + // For files: "file:" + file path (e.g., "file:migrations/001.sql") + // For stdin: "stdin" + // For command-line: "cmdline" + Origin string + + // SQL contains the actual SQL content + SQL string +} + +// resolveStatement takes a single --statement argument and returns one or more StatementSources. +// - Inline SQL → 1 StatementSource with Origin="cmdline" +// - "-" (stdin) → 1 StatementSource with Origin="stdin" +// - "file:path.sql" → 1 StatementSource with Origin="file:path.sql" +// - "file:dir/" → N StatementSources (one per .sql file in directory, recursively) +// - "file:*.sql" → N StatementSources (one per matching file) +func resolveStatement(arg string) ([]StatementSource, error) { + // Check for stdin + if arg == "-" { + content, err := io.ReadAll(os.Stdin) + if err != nil { + return nil, fmt.Errorf("failed to read from stdin: %w", err) + } + + return []StatementSource{{ + Origin: "stdin", + SQL: string(content), + }}, nil + } + + // Check for file: prefix + if strings.HasPrefix(arg, "file:") { + path := strings.TrimPrefix(arg, "file:") + + // Check if it's a glob pattern (contains wildcard characters) + if strings.ContainsAny(path, "*?[]") { + return resolveGlob(path) + } + + // Try to stat the path to determine if it's a file or directory + info, err := os.Stat(path) + if err != nil { + return nil, fmt.Errorf("failed to access %s: %w", path, err) + } + + if info.IsDir() { + return resolveDirectory(path) + } + + return resolveFile(path) + } + + // Default to command-line SQL + return []StatementSource{{ + Origin: "cmdline", + SQL: arg, + }}, nil +} + +// resolveFile reads a single SQL file and returns a StatementSource +func resolveFile(path string) ([]StatementSource, error) { + content, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read file %s: %w", path, err) + } + + return []StatementSource{{ + Origin: "file:" + path, + SQL: string(content), + }}, nil +} + +// resolveDirectory recursively finds all .sql files in a directory and returns StatementSources +func resolveDirectory(dir string) ([]StatementSource, error) { + var sources []StatementSource + + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // Skip directories and non-.sql files + if info.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".sql") { + return nil + } + + content, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("failed to read file %s: %w", path, err) + } + + sources = append(sources, StatementSource{ + Origin: "file:" + path, + SQL: string(content), + }) + + return nil + }) + if err != nil { + return nil, err + } + + if len(sources) == 0 { + return nil, fmt.Errorf("no .sql files found in directory: %s", dir) + } + + return sources, nil +} + +// resolveGlob expands a glob pattern and returns StatementSources for all matching files +func resolveGlob(pattern string) ([]StatementSource, error) { + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, fmt.Errorf("invalid glob pattern %s: %w", pattern, err) + } + + if len(matches) == 0 { + return nil, fmt.Errorf("no files matched glob pattern: %s", pattern) + } + + var sources []StatementSource + + for _, path := range matches { + // Skip directories + info, err := os.Stat(path) + if err != nil { + return nil, fmt.Errorf("failed to stat file %s: %w", path, err) + } + + if info.IsDir() { + continue + } + + content, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read file %s: %w", path, err) + } + + sources = append(sources, StatementSource{ + Origin: "file:" + path, + SQL: string(content), + }) + } + + if len(sources) == 0 { + return nil, fmt.Errorf("glob pattern matched only directories: %s", pattern) + } + + return sources, nil +} + +// parseStatementSource parses a single StatementSource and extracts CREATE TABLE and ALTER TABLE statements. +// Returns the parsed statements and any error encountered. +// Note: Due to limitations in statement.New(), a single source cannot contain both CREATE TABLE and ALTER TABLE statements. +func parseStatementSource(source StatementSource) ([]*statement.CreateTable, []*statement.AbstractStatement, error) { + sql := strings.TrimSpace(source.SQL) + if sql == "" { + return nil, nil, nil // Empty source is OK + } + + var ( + createTables []*statement.CreateTable + alterStatements []*statement.AbstractStatement + ) + + // Parse all statements + stmts, err := statement.New(sql) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse %s: %w", source.Origin, err) + } + + // Categorize statements + for _, stmt := range stmts { + if stmt.IsAlterTable() { + alterStatements = append(alterStatements, stmt) + } else { + // It's a CREATE TABLE, parse into structured format + ct, err := statement.ParseCreateTable(stmt.Statement) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse CREATE TABLE from %s: %w", source.Origin, err) + } + + createTables = append(createTables, ct) + } + } + + return createTables, alterStatements, nil +} + +// Lint is the struct for the lint command +type Lint struct { + Statement []string `help:"CREATE TABLE and ALTER TABLE statements to lint" sep:"none"` + Linters []string `help:"Specific linters to run (default: all)" default:"all"` + Config []string `help:"Individual linter configuration properties"` +} + +func (l *Lint) Run() error { + var ( + allCreateTables []*statement.CreateTable + allAlterStatements []*statement.AbstractStatement + lintConfig Config + ) + + if len(l.Statement) == 0 { + return errors.New("must specify at least one statement to lint") + } + + // Resolve all statement arguments into sources + var sources []StatementSource + + for _, arg := range l.Statement { + s, err := resolveStatement(arg) + if err != nil { + return err + } + + sources = append(sources, s...) + } + + // Parse each source + for _, source := range sources { + createTables, alterStatements, err := parseStatementSource(source) + if err != nil { + return err + } + + if len(createTables) == 0 && len(alterStatements) == 0 { + fmt.Fprintf(os.Stderr, "Warning: no valid statements found in %s, skipping\n", source.Origin) + continue // No valid statements in this source + } + + allCreateTables = append(allCreateTables, createTables...) + allAlterStatements = append(allAlterStatements, alterStatements...) + } + + // Run linters + violations, err := RunLinters(allCreateTables, allAlterStatements, lintConfig) + if err != nil { + return fmt.Errorf("failed to run linters: %w", err) + } + + if len(violations) == 0 { + fmt.Println("No lint violations found") + return nil + } + + for _, v := range violations { + fmt.Println(v.String()) + } + + return errors.New("lint violations found") +} diff --git a/pkg/lint/cmd_test.go b/pkg/lint/cmd_test.go new file mode 100644 index 00000000..cabd20d1 --- /dev/null +++ b/pkg/lint/cmd_test.go @@ -0,0 +1,396 @@ +package lint + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResolveStatement_Cmdline(t *testing.T) { + tests := []struct { + name string + arg string + }{ + { + name: "CREATE TABLE", + arg: "CREATE TABLE users (id BIGINT PRIMARY KEY)", + }, + { + name: "ALTER TABLE", + arg: "ALTER TABLE users ADD COLUMN email VARCHAR(255)", + }, + { + name: "multiline SQL", + arg: "CREATE TABLE users (\n id BIGINT PRIMARY KEY\n)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sources, err := resolveStatement(tt.arg) + require.NoError(t, err) + require.Len(t, sources, 1) + assert.Equal(t, "cmdline", sources[0].Origin) + assert.Equal(t, tt.arg, sources[0].SQL) + }) + } +} + +func TestResolveStatement_Stdin(t *testing.T) { + // Mock stdin + oldStdin := os.Stdin + + defer func() { os.Stdin = oldStdin }() + + r, w, _ := os.Pipe() + os.Stdin = r + + sql := "CREATE TABLE users (id BIGINT PRIMARY KEY)" + + go func() { + _, _ = w.WriteString(sql) + w.Close() + }() + + sources, err := resolveStatement("-") + require.NoError(t, err) + require.Len(t, sources, 1) + assert.Equal(t, "stdin", sources[0].Origin) + assert.Equal(t, sql, sources[0].SQL) +} + +func TestResolveStatement_File(t *testing.T) { + // Create a temporary file + tmpfile, err := os.CreateTemp(t.TempDir(), "test_*.sql") + require.NoError(t, err) + + defer os.Remove(tmpfile.Name()) + + sql := "CREATE TABLE users (id BIGINT PRIMARY KEY)" + _, err = tmpfile.WriteString(sql) + require.NoError(t, err) + tmpfile.Close() + + sources, err := resolveStatement("file:" + tmpfile.Name()) + require.NoError(t, err) + require.Len(t, sources, 1) + assert.Equal(t, "file:"+tmpfile.Name(), sources[0].Origin) + assert.Equal(t, sql, sources[0].SQL) +} + +func TestResolveStatement_FileNotExists(t *testing.T) { + sources, err := resolveStatement("file:/nonexistent/path/to/file.sql") + assert.Error(t, err) + assert.Nil(t, sources) + assert.Contains(t, err.Error(), "failed to access") +} + +func TestResolveStatement_Directory(t *testing.T) { + // Create a temporary directory with SQL files + tmpdir := t.TempDir() + + // Create some files + sql1 := "CREATE TABLE users (id INT)" + sql2 := "CREATE TABLE orders (id INT)" + + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte(sql1), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte(sql2), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "README.md"), []byte("# Migrations"), 0644)) + + // Create a subdirectory with a file + subdir := filepath.Join(tmpdir, "archived") + require.NoError(t, os.Mkdir(subdir, 0755)) + + sql3 := "CREATE TABLE old (id INT)" + require.NoError(t, os.WriteFile(filepath.Join(subdir, "old.sql"), []byte(sql3), 0644)) + + sources, err := resolveStatement("file:" + tmpdir) + require.NoError(t, err) + require.Len(t, sources, 3) // Should find all 3 .sql files recursively + + // Verify origins have file: prefix + for _, source := range sources { + assert.True(t, strings.HasPrefix(source.Origin, "file:")) + assert.True(t, strings.HasSuffix(source.Origin, ".sql")) + } + + // Verify we got the right content + sqlContents := []string{sources[0].SQL, sources[1].SQL, sources[2].SQL} + assert.Contains(t, sqlContents, sql1) + assert.Contains(t, sqlContents, sql2) + assert.Contains(t, sqlContents, sql3) +} + +func TestResolveStatement_DirectoryEmpty(t *testing.T) { + tmpdir := t.TempDir() + + sources, err := resolveStatement("file:" + tmpdir) + assert.Error(t, err) + assert.Nil(t, sources) + assert.Contains(t, err.Error(), "no .sql files found") +} + +func TestResolveStatement_Glob(t *testing.T) { + // Create a temporary directory with SQL files + tmpdir := t.TempDir() + + // Create some files + sql1 := "CREATE TABLE users (id INT)" + sql2 := "CREATE TABLE orders (id INT)" + sql3 := "CREATE TABLE products (id INT)" + + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte(sql1), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte(sql2), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "003_products.sql"), []byte(sql3), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "README.md"), []byte("# Migrations"), 0644)) + + // Test glob pattern + pattern := "file:" + filepath.Join(tmpdir, "*.sql") + sources, err := resolveStatement(pattern) + require.NoError(t, err) + require.Len(t, sources, 3) // Should find all 3 .sql files + + // Verify origins + for _, source := range sources { + assert.True(t, strings.HasPrefix(source.Origin, "file:")) + assert.True(t, strings.HasSuffix(source.Origin, ".sql")) + } +} + +func TestResolveStatement_GlobNoMatches(t *testing.T) { + tmpdir := t.TempDir() + + pattern := "file:" + filepath.Join(tmpdir, "*.sql") + sources, err := resolveStatement(pattern) + assert.Error(t, err) + assert.Nil(t, sources) + assert.Contains(t, err.Error(), "no files matched glob pattern") +} + +func TestResolveStatement_GlobWithPattern(t *testing.T) { + tmpdir := t.TempDir() + + // Create files with different patterns + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte("CREATE TABLE users (id INT)"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte("CREATE TABLE orders (id INT)"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "999_old.sql"), []byte("CREATE TABLE old (id INT)"), 0644)) + + // Test pattern that matches only 001 and 002 + pattern := "file:" + filepath.Join(tmpdir, "00[12]*.sql") + sources, err := resolveStatement(pattern) + require.NoError(t, err) + require.Len(t, sources, 2) +} + +func TestResolveStatement_GlobSkipsDirectories(t *testing.T) { + tmpdir := t.TempDir() + + // Create a subdirectory that would match the glob + subdir := filepath.Join(tmpdir, "migrations") + require.NoError(t, os.Mkdir(subdir, 0755)) + + // Create a file + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "001.sql"), []byte("CREATE TABLE users (id INT)"), 0644)) + + // Glob should skip the directory + pattern := "file:" + filepath.Join(tmpdir, "*") + sources, err := resolveStatement(pattern) + require.NoError(t, err) + require.Len(t, sources, 1) // Only the file, not the directory +} + +func TestResolveStatement_Integration(t *testing.T) { + // Create a realistic test directory structure + tmpdir := t.TempDir() + + // Create some files + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "001_users.sql"), []byte("CREATE TABLE users (id INT)"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(tmpdir, "002_orders.sql"), []byte("CREATE TABLE orders (id INT)"), 0644)) + + tests := []struct { + name string + arg string + expectedCount int + }{ + { + name: "specific file", + arg: "file:" + filepath.Join(tmpdir, "001_users.sql"), + expectedCount: 1, + }, + { + name: "directory", + arg: "file:" + tmpdir, + expectedCount: 2, + }, + { + name: "glob all sql", + arg: "file:" + filepath.Join(tmpdir, "*.sql"), + expectedCount: 2, + }, + { + name: "glob with pattern", + arg: "file:" + filepath.Join(tmpdir, "001*.sql"), + expectedCount: 1, + }, + { + name: "cmdline", + arg: "CREATE TABLE test (id INT)", + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sources, err := resolveStatement(tt.arg) + require.NoError(t, err) + assert.Len(t, sources, tt.expectedCount) + }) + } +} + +func TestParseStatementSource_Empty(t *testing.T) { + source := StatementSource{ + Origin: "cmdline", + SQL: "", + } + + createTables, alterStatements, err := parseStatementSource(source) + require.NoError(t, err) + assert.Nil(t, createTables) + assert.Nil(t, alterStatements) +} + +func TestParseStatementSource_WhitespaceOnly(t *testing.T) { + source := StatementSource{ + Origin: "cmdline", + SQL: " \n\t ", + } + + createTables, alterStatements, err := parseStatementSource(source) + require.NoError(t, err) + assert.Nil(t, createTables) + assert.Nil(t, alterStatements) +} + +func TestParseStatementSource_SingleCreateTable(t *testing.T) { + source := StatementSource{ + Origin: "cmdline", + SQL: "CREATE TABLE users (id BIGINT UNSIGNED PRIMARY KEY, email VARCHAR(255))", + } + + createTables, alterStatements, err := parseStatementSource(source) + require.NoError(t, err) + require.Len(t, createTables, 1) + assert.Empty(t, alterStatements) + assert.Equal(t, "users", createTables[0].GetTableName()) +} + +func TestParseStatementSource_SingleAlterTable(t *testing.T) { + source := StatementSource{ + Origin: "cmdline", + SQL: "ALTER TABLE users ADD COLUMN email VARCHAR(255)", + } + + createTables, alterStatements, err := parseStatementSource(source) + require.NoError(t, err) + assert.Empty(t, createTables) + require.Len(t, alterStatements, 1) + assert.Equal(t, "users", alterStatements[0].Table) +} + +func TestParseStatementSource_MultipleAlterStatements(t *testing.T) { + source := StatementSource{ + Origin: "file:migrations/001.sql", + SQL: ` + ALTER TABLE users ADD COLUMN email VARCHAR(255); + ALTER TABLE users ADD INDEX idx_email (email); + `, + } + + createTables, alterStatements, err := parseStatementSource(source) + require.NoError(t, err) + assert.Empty(t, createTables) + require.Len(t, alterStatements, 2) + assert.Equal(t, "users", alterStatements[0].Table) + assert.Equal(t, "users", alterStatements[1].Table) +} + +func TestParseStatementSource_MixedStatements(t *testing.T) { + source := StatementSource{ + Origin: "file:schema.sql", + SQL: ` + CREATE TABLE users (id BIGINT UNSIGNED PRIMARY KEY); + ALTER TABLE users ADD INDEX idx_email (email); + `, + } + + // Mixed statements should fail due to statement.New() limitation + createTables, alterStatements, err := parseStatementSource(source) + assert.Error(t, err) + assert.Nil(t, createTables) + assert.Nil(t, alterStatements) + assert.Contains(t, err.Error(), "failed to parse file:schema.sql") +} + +func TestParseStatementSource_MultipleCreateStatements(t *testing.T) { + source := StatementSource{ + Origin: "file:schema.sql", + SQL: ` + CREATE TABLE users (id BIGINT UNSIGNED PRIMARY KEY); + CREATE TABLE orders (id BIGINT UNSIGNED PRIMARY KEY); + `, + } + + // Multiple CREATE statements should fail due to statement.New() limitation + createTables, alterStatements, err := parseStatementSource(source) + assert.Error(t, err) + assert.Nil(t, createTables) + assert.Nil(t, alterStatements) + assert.Contains(t, err.Error(), "failed to parse file:schema.sql") +} + +func TestParseStatementSource_InvalidSQL(t *testing.T) { + source := StatementSource{ + Origin: "cmdline", + SQL: "INVALID SQL STATEMENT", + } + + createTables, alterStatements, err := parseStatementSource(source) + assert.Error(t, err) + assert.Nil(t, createTables) + assert.Nil(t, alterStatements) + assert.Contains(t, err.Error(), "failed to parse cmdline") +} + +func TestParseStatementSource_ErrorContext(t *testing.T) { + source := StatementSource{ + Origin: "file:migrations/bad.sql", + SQL: "CREATE TABLE", + } + + _, _, err := parseStatementSource(source) + assert.Error(t, err) + assert.Contains(t, err.Error(), "file:migrations/bad.sql") +} + +func TestParseStatementSource_WithComments(t *testing.T) { + source := StatementSource{ + Origin: "file:schema.sql", + SQL: ` + -- This is a comment + ALTER TABLE users ADD COLUMN email VARCHAR(255); + /* Multi-line + comment */ + ALTER TABLE users ADD INDEX idx_email (email); + `, + } + + createTables, alterStatements, err := parseStatementSource(source) + require.NoError(t, err) + assert.Empty(t, createTables) + require.Len(t, alterStatements, 2) +} diff --git a/pkg/lint/example/example.go b/pkg/lint/example/example.go new file mode 100644 index 00000000..c9055340 --- /dev/null +++ b/pkg/lint/example/example.go @@ -0,0 +1,126 @@ +// Package example provides example linters to demonstrate the linter framework. +// These linters are for demonstration purposes and are not registered by default. +package example + +import ( + "fmt" + "strconv" + + "github.com/block/spirit/pkg/lint" + "github.com/block/spirit/pkg/statement" +) + +// TableNameLengthLinter checks that table names are not too long. +// MySQL has a limit of 64 characters for table names. +type TableNameLengthLinter struct { + maxLength int +} + +// NewTableNameLengthLinter creates a new table name length linter with default configuration. +func NewTableNameLengthLinter() *TableNameLengthLinter { + return &TableNameLengthLinter{ + maxLength: 58, // MySQL's limit is 64 but we use 58 to allow for prefixes/suffixes + } +} + +func (l *TableNameLengthLinter) String() string { + return lint.Stringer(l) +} + +func (l *TableNameLengthLinter) Name() string { + return "table_name_length" +} + +func (l *TableNameLengthLinter) Description() string { + return "Checks that table names do not exceed the configured maximum length (default: 58 characters)" +} + +func (l *TableNameLengthLinter) DefaultConfig() map[string]string { + return map[string]string{ + "maxLength": "58", + } +} + +func (l *TableNameLengthLinter) Configure(config map[string]string) error { + for k, v := range config { + switch k { + case "maxLength": + maxLen, err := strconv.Atoi(v) + if err != nil { + return fmt.Errorf("maxLength must be a valid integer, got %q: %w", v, err) + } + + if maxLen <= 0 { + return fmt.Errorf("maxLength must be positive, got %d", maxLen) + } + + l.maxLength = maxLen + default: + return fmt.Errorf("unknown config key for %s: %s", l.Name(), k) + } + } + + return nil +} + +func (l *TableNameLengthLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []lint.Violation { + var violations []lint.Violation + + for _, ct := range createTables { + tableName := ct.GetTableName() + if len(tableName) > l.maxLength { + violations = append(violations, lint.Violation{ + Linter: l, + Severity: lint.SeverityError, + Message: fmt.Sprintf("Table name '%s' exceeds maximum length of %d characters (actual: %d)", tableName, l.maxLength, len(tableName)), + Location: &lint.Location{ + Table: tableName, + }, + }) + } + } + + return violations +} + +// DuplicateColumnLinter checks for duplicate column names in CREATE TABLE statements. +type DuplicateColumnLinter struct{} + +func (l *DuplicateColumnLinter) String() string { + return l.Name() +} + +func (l *DuplicateColumnLinter) Name() string { + return "duplicate_column" +} + +func (l *DuplicateColumnLinter) Description() string { + return "Detects duplicate column definitions in CREATE TABLE statements" +} + +func (l *DuplicateColumnLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []lint.Violation { + var violations []lint.Violation + + for _, ct := range createTables { + tableName := ct.GetTableName() + seen := make(map[string]bool) + + for _, col := range ct.GetColumns() { + if seen[col.Name] { + violations = append(violations, lint.Violation{ + Linter: l, + Severity: lint.SeverityError, + Message: fmt.Sprintf("Duplicate column definition: '%s'", col.Name), + Location: &lint.Location{ + Table: tableName, + Column: &col.Name, + }, + }) + } + + seen[col.Name] = true + } + } + + return violations +} diff --git a/pkg/lint/example/example_test.go b/pkg/lint/example/example_test.go new file mode 100644 index 00000000..c14c59f6 --- /dev/null +++ b/pkg/lint/example/example_test.go @@ -0,0 +1,255 @@ +package example + +import ( + "testing" + + "github.com/block/spirit/pkg/lint" + "github.com/block/spirit/pkg/statement" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTableNameLengthLinter_Valid(t *testing.T) { + sql := "CREATE TABLE users (id INT PRIMARY KEY)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := NewTableNameLengthLinter() + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + assert.Empty(t, violations) +} + +func TestTableNameLengthLinter_TooLong(t *testing.T) { + // Create a table name that's 65 characters (exceeds MySQL's 64 char limit) + longName := "this_is_a_very_long_table_name_that_exceeds_the_mysql_limit_abcde" + require.Len(t, longName, 65, "Test setup: name should be 65 chars") + + sql := "CREATE TABLE " + longName + " (id INT PRIMARY KEY)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := NewTableNameLengthLinter() + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + require.Len(t, violations, 1) + assert.Equal(t, "table_name_length", violations[0].Linter.Name()) + assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Contains(t, violations[0].Message, "exceeds maximum length") + assert.Equal(t, longName, violations[0].Location.Table) +} + +func TestTableNameLengthLinter_ExactlyAtLimit(t *testing.T) { + // Create a table name that's exactly 58 characters + exactName := "abcdefghij_abcdefghij_abcdefghij_abcdefghij_abcdefghij_abc" + require.Len(t, exactName, 58, "Test setup: name should be exactly 58 chars") + + sql := "CREATE TABLE " + exactName + " (id INT PRIMARY KEY)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := NewTableNameLengthLinter() + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + assert.Empty(t, violations, "58 character name should be allowed") +} + +func TestTableNameLengthLinter_Configure(t *testing.T) { + // Create a table name that's 50 characters + name := "abcdefghij_abcdefghij_abcdefghij_abcdefghij_abcdef" + require.Len(t, name, 50, "Test setup: name should be 50 chars") + + sql := "CREATE TABLE " + name + " (id INT PRIMARY KEY)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := NewTableNameLengthLinter() + + // With default config (58), should pass + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + assert.Empty(t, violations) + + // Configure to max length of 40 + err = linter.Configure(map[string]string{"maxLength": "40"}) + require.NoError(t, err) + + // Now should fail + violations = linter.Lint([]*statement.CreateTable{ct}, nil) + require.Len(t, violations, 1) + assert.Contains(t, violations[0].Message, "exceeds maximum length of 40") +} + +func TestTableNameLengthLinter_Configure_InvalidConfig(t *testing.T) { + linter := NewTableNameLengthLinter() + + // Invalid integer + err := linter.Configure(map[string]string{"maxLength": "invalid"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must be a valid integer") + + // Zero length + err = linter.Configure(map[string]string{"maxLength": "0"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must be positive") + + // Negative length + err = linter.Configure(map[string]string{"maxLength": "-1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must be positive") + + // Unknown key + err = linter.Configure(map[string]string{"unknownKey": "value"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown config key") +} + +func TestTableNameLengthLinter_DefaultConfig(t *testing.T) { + linter := NewTableNameLengthLinter() + + config := linter.DefaultConfig() + require.NotNil(t, config) + assert.Equal(t, "58", config["maxLength"]) +} + +func TestDuplicateColumnLinter_NoDuplicates(t *testing.T) { + sql := "CREATE TABLE users (id INT, name VARCHAR(100), email VARCHAR(255))" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &DuplicateColumnLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + assert.Empty(t, violations) +} + +func TestDuplicateColumnLinter_WithDuplicates(t *testing.T) { + sql := "CREATE TABLE users (id INT, name VARCHAR(100), id INT)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &DuplicateColumnLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + require.Len(t, violations, 1) + assert.Equal(t, "duplicate_column", violations[0].Linter.Name()) + assert.Equal(t, lint.SeverityError, violations[0].Severity) + assert.Contains(t, violations[0].Message, "Duplicate column definition") + assert.Equal(t, "users", violations[0].Location.Table) + assert.NotNil(t, violations[0].Location.Column) + assert.Equal(t, "id", *violations[0].Location.Column) +} + +func TestDuplicateColumnLinter_MultipleDuplicates(t *testing.T) { + sql := "CREATE TABLE test (id INT, name VARCHAR(100), id INT, name VARCHAR(100))" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &DuplicateColumnLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + require.Len(t, violations, 2, "Should detect both duplicate columns") + + // Check that both duplicates are reported + duplicates := make(map[string]bool) + + for _, v := range violations { + if v.Location.Column != nil { + duplicates[*v.Location.Column] = true + } + } + + assert.True(t, duplicates["id"]) + assert.True(t, duplicates["name"]) +} + +func TestExampleLinters_Integration(t *testing.T) { + // Reset the global registry + lint.Reset() + + // Register our example linters + lint.Register(NewTableNameLengthLinter()) + lint.Register(&DuplicateColumnLinter{}) + + // Create a table with both issues + longName := "this_is_a_very_long_table_name_that_exceeds_the_mysql_limit_abcde" + require.Len(t, longName, 65, "Test setup: name should be 65 chars") + + sql := "CREATE TABLE " + longName + " (id INT, name VARCHAR(100), id INT)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + // Run all linters + violations, err := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{}) + require.NoError(t, err) + + // Should have violations from both linters + require.Len(t, violations, 2) + + // Check that we have one violation from each linter + linterCounts := make(map[string]int) + for _, v := range violations { + linterCounts[v.Linter.Name()]++ + } + + assert.Equal(t, 1, linterCounts["table_name_length"]) + assert.Equal(t, 1, linterCounts["duplicate_column"]) +} + +func TestExampleLinters_WithConfig(t *testing.T) { + lint.Reset() + + lint.Register(NewTableNameLengthLinter()) + lint.Register(&DuplicateColumnLinter{}) + + longName := "this_is_a_very_long_table_name_that_exceeds_the_mysql_limit_abc" + sql := "CREATE TABLE " + longName + " (id INT, name VARCHAR(100), id INT)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + // Disable the table name length linter + violations, err := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{ + Enabled: map[string]bool{ + "table_name_length": false, + "duplicate_column": true, + }, + }) + require.NoError(t, err) + + // Should only have violation from duplicate_column linter + require.Len(t, violations, 1) + assert.Equal(t, "duplicate_column", violations[0].Linter.Name()) +} + +func TestTableNameLengthLinter_WithConfigSettings(t *testing.T) { + lint.Reset() + + lint.Register(NewTableNameLengthLinter()) + + // Create a table name that's 50 characters + name := "abcdefghij_abcdefghij_abcdefghij_abcdefghij_abcdef" + require.Len(t, name, 50, "Test setup: name should be 50 chars") + + sql := "CREATE TABLE " + name + " (id INT PRIMARY KEY)" + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + // With default config (58), should pass + violations, err := lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{}) + require.NoError(t, err) + assert.Empty(t, violations) + + // Configure max length to 40 via Config.Settings + violations, err = lint.RunLinters([]*statement.CreateTable{ct}, nil, lint.Config{ + Settings: map[string]map[string]string{ + "table_name_length": { + "maxLength": "40", + }, + }, + }) + require.NoError(t, err) + + // Should now have a violation + require.Len(t, violations, 1) + assert.Equal(t, "table_name_length", violations[0].Linter.Name()) + assert.Contains(t, violations[0].Message, "exceeds maximum length of 40") +} diff --git a/pkg/lint/lint.go b/pkg/lint/lint.go new file mode 100644 index 00000000..ea8a3a48 --- /dev/null +++ b/pkg/lint/lint.go @@ -0,0 +1,197 @@ +// Package lint provides a framework for static analysis of MySQL schema definitions +// and DDL statements. It enables validation and best-practice enforcement beyond +// the runtime checks provided by the check package. +// +// The linter framework operates on parsed CREATE TABLE statements rather than live +// database connections. +// +// # Basic Usage +// +// Linters are registered via init() functions and executed via RunLinters(): +// +// package naming +// +// func init() { +// lint.Register(TableNameLinter{}) +// } +// +// // Later, run all linters: +// violations := lint.RunLinters(tables, stmts, config) +// +// # Creating a Linter +// +// To create a custom linter, implement the Linter interface: +// +// type MyLinter struct{} +// +// func (l *MyLinter) Name() string { return "my_linter" } +// func (l *MyLinter) Category() string { return "custom" } +// func (l *MyLinter) Description() string { return "My custom linter" } +// func (l *MyLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []lint.Violation { +// // Perform linting logic +// return violations +// } +// +// # Configuration +// +// Linters can be enabled/disabled via the Config.Enabled map: +// +// config := lint.Config{ +// Enabled: map[string]bool{ +// "table_name": true, +// "column_name": false, +// }, +// } +// +// Configurable linters can implement the ConfigurableLinter interface to accept +// custom settings via Config.Settings. Settings must be provided as map[string]string: +// +// config := lint.Config{ +// Settings: map[string]map[string]string{ +// "my_linter": { +// "option1": "value1", +// "option2": "value2", +// }, +// }, +// } +package lint + +import ( + "errors" + "fmt" + "os" + + "github.com/block/spirit/pkg/statement" +) + +// Config holds linter configuration +type Config struct { + // Enabled maps linter names to whether they are enabled + // If a linter is not in this map, it uses its default enabled state + Enabled map[string]bool + + // Settings maps linter names to their configuration as map[string]string + // Each linter's settings are provided as key-value string pairs + Settings map[string]map[string]string +} + +// RunLinters runs all enabled linters and returns any violations found. +// Linters are executed in an undefined order. +// +// A linter is executed if: +// - It is enabled by default (set during Register), AND +// - It is not explicitly disabled in config.Enabled +// +// OR: +// - It is explicitly enabled in config.Enabled +// +// If a linter implements ConfigurableLinter and has settings in config.Settings, +// those settings are applied before running the linter. +func RunLinters(createTables []*statement.CreateTable, alterStatements []*statement.AbstractStatement, config Config) ([]Violation, error) { + var errs []error + + lock.RLock() + defer lock.RUnlock() + + var violations []Violation + + for name, linter := range linters { + // Check if linter is explicitly disabled in config + if enabled, ok := config.Enabled[name]; ok && !enabled { + continue + } + + // Check if linter is explicitly enabled in config + explicitlyEnabled := false + if enabled, ok := config.Enabled[name]; ok && enabled { + explicitlyEnabled = true + } + + // Skip if not enabled by default and not explicitly enabled + if !linter.enabled && !explicitlyEnabled { + continue + } + + // Apply configuration if available + if configurableLinter, ok := linter.l.(ConfigurableLinter); ok { + // Start with default config + defaultConfig := configurableLinter.DefaultConfig() + + // Merge user settings with defaults (user settings override defaults) + finalConfig := make(map[string]string) + for k, v := range defaultConfig { + finalConfig[k] = v + } + + if settings, ok := config.Settings[name]; ok { + for k, v := range settings { + finalConfig[k] = v + } + } + + // Apply the merged configuration + err := configurableLinter.Configure(finalConfig) + if err != nil { + // Configuration error - skip this linter + fmt.Fprintf(os.Stderr, "Error configuring %s: %s\n", name, err) + errs = append(errs, err) + + continue + } + } + + // Run the linter + lintViolations := linter.l.Lint(createTables, alterStatements) + violations = append(violations, lintViolations...) + } + + return violations, errors.Join(errs...) +} + +// HasErrors returns true if any violations have ERROR severity. +func HasErrors(violations []Violation) bool { + for _, v := range violations { + if v.Severity == SeverityError { + return true + } + } + + return false +} + +// HasWarnings returns true if any violations have WARNING severity. +func HasWarnings(violations []Violation) bool { + for _, v := range violations { + if v.Severity == SeverityWarning { + return true + } + } + + return false +} + +// FilterBySeverity returns only violations with the specified severity. +func FilterBySeverity(violations []Violation, severity Severity) []Violation { + var filtered []Violation + + for _, v := range violations { + if v.Severity == severity { + filtered = append(filtered, v) + } + } + + return filtered +} + +// FilterByLinter returns only violations from the specified linter. +func FilterByLinter(violations []Violation, linterName string) []Violation { + var filtered []Violation + + for _, v := range violations { + if v.Linter.Name() == linterName { + filtered = append(filtered, v) + } + } + + return filtered +} diff --git a/pkg/lint/lint_invisible_index.go b/pkg/lint/lint_invisible_index.go new file mode 100644 index 00000000..30f24074 --- /dev/null +++ b/pkg/lint/lint_invisible_index.go @@ -0,0 +1,119 @@ +package lint + +import ( + "fmt" + + "github.com/block/spirit/pkg/statement" + "github.com/pingcap/tidb/pkg/parser/ast" +) + +func init() { + Register(&InvisibleIndexBeforeDropLinter{}) +} + +// InvisibleIndexBeforeDropLinter checks that indexes are made invisible before dropping. +// This is a safety practice to ensure the index is not needed before permanently removing it. +type InvisibleIndexBeforeDropLinter struct { + raiseError bool +} + +func (l *InvisibleIndexBeforeDropLinter) String() string { + return Stringer(l) +} + +func (l *InvisibleIndexBeforeDropLinter) Name() string { + return "invisible_index_before_drop" +} + +func (l *InvisibleIndexBeforeDropLinter) Description() string { + return "Requires indexes to be made invisible before dropping them as a safety measure" +} + +func (l *InvisibleIndexBeforeDropLinter) Configure(config map[string]string) error { + for k, v := range config { + switch k { + case "raiseError": + boolVal, err := ConfigBool(v, k) + if err != nil { + return err + } + + l.raiseError = boolVal + default: + return fmt.Errorf("unknown config key for %s: %s", l.Name(), k) + } + } + + return nil +} + +func (l *InvisibleIndexBeforeDropLinter) DefaultConfig() map[string]string { + return map[string]string{ + "raiseError": "false", + } +} + +var _ ConfigurableLinter = &InvisibleIndexBeforeDropLinter{} + +func (l *InvisibleIndexBeforeDropLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []Violation { + severity := SeverityWarning + if l.raiseError { + severity = SeverityError + } + + var violations []Violation + + for _, stmt := range statements { + // Only check ALTER TABLE statements + if !stmt.IsAlterTable() { + continue + } + + alterStmt, ok := (*stmt.StmtNode).(*ast.AlterTableStmt) + if !ok { + continue + } + + tableName := stmt.Table + + // Check each ALTER specification + for _, spec := range alterStmt.Specs { + if spec.Tp != ast.AlterTableDropIndex { + continue + } + + indexName := spec.Name + + madeInvisible := false + // If not made invisible in this ALTER, check if it's invisible in the CREATE TABLE + if len(createTables) > 0 { + for _, ct := range createTables { + if ct.GetTableName() == tableName { + for _, idx := range ct.GetIndexes() { + if idx.Name == indexName && idx.Invisible != nil && *idx.Invisible { + madeInvisible = true + break + } + } + } + } + } + + if !madeInvisible { + suggestion := fmt.Sprintf("First make the index invisible: ALTER TABLE %s ALTER INDEX %s INVISIBLE", tableName, indexName) + violations = append(violations, Violation{ + Linter: l, + Severity: severity, + Message: fmt.Sprintf("Index '%s' should be made invisible before dropping to ensure it's not needed", indexName), + Location: &Location{ + Table: tableName, + Index: &indexName, + }, + Suggestion: &suggestion, + }) + } + } + } + + return violations +} diff --git a/pkg/lint/lint_invisible_index_test.go b/pkg/lint/lint_invisible_index_test.go new file mode 100644 index 00000000..1606d0f3 --- /dev/null +++ b/pkg/lint/lint_invisible_index_test.go @@ -0,0 +1,452 @@ +package lint + +import ( + "testing" + + "github.com/block/spirit/pkg/statement" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInvisibleIndexBeforeDropLinter_DropWithoutInvisible(t *testing.T) { + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + assert.Equal(t, "invisible_index_before_drop", violations[0].Linter.Name()) + assert.Equal(t, SeverityWarning, violations[0].Severity) + assert.Contains(t, violations[0].Message, "should be made invisible before dropping") + assert.Equal(t, "users", violations[0].Location.Table) + assert.NotNil(t, violations[0].Location.Index) + assert.Equal(t, "idx_email", *violations[0].Location.Index) + assert.NotNil(t, violations[0].Suggestion) + assert.Contains(t, *violations[0].Suggestion, "ALTER INDEX idx_email INVISIBLE") +} + +func TestInvisibleIndexBeforeDropLinter_DropAfterInvisibleInSameAlter(t *testing.T) { + sql := "ALTER TABLE users ALTER INDEX idx_email INVISIBLE, DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint(nil, stmts) + + // Making an index invisible in the same ALTER statement where you drop it is obviously not good enough + assert.Len(t, violations, 1) + assert.IsType(t, &InvisibleIndexBeforeDropLinter{}, violations[0].Linter) + assert.Equal(t, "invisible_index_before_drop", violations[0].Linter.Name()) +} + +func TestInvisibleIndexBeforeDropLinter_DropAlreadyInvisibleIndex(t *testing.T) { + // Create a table with an invisible index + createSQL := `CREATE TABLE users ( + id INT PRIMARY KEY, + email VARCHAR(255), + INDEX idx_email (email) INVISIBLE + )` + ct, err := statement.ParseCreateTable(createSQL) + require.NoError(t, err) + + // Drop the invisible index + alterSQL := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(alterSQL) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, stmts) + + // Should not have violations since index is already invisible + assert.Empty(t, violations) +} + +func TestInvisibleIndexBeforeDropLinter_DropVisibleIndex(t *testing.T) { + // Create a table with a visible index + createSQL := `CREATE TABLE users ( + id INT PRIMARY KEY, + email VARCHAR(255), + INDEX idx_email (email) + )` + ct, err := statement.ParseCreateTable(createSQL) + require.NoError(t, err) + + // Drop the visible index + alterSQL := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(alterSQL) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, stmts) + + // Should have violation since index is visible + require.Len(t, violations, 1) + assert.Equal(t, "invisible_index_before_drop", violations[0].Linter.Name()) +} + +func TestInvisibleIndexBeforeDropLinter_MultipleDrops(t *testing.T) { + sql := "ALTER TABLE users DROP INDEX idx_email, DROP INDEX idx_name" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint(nil, stmts) + + // Should have violations for both indexes + require.Len(t, violations, 2) + + indexNames := make(map[string]bool) + + for _, v := range violations { + assert.Equal(t, "invisible_index_before_drop", v.Linter.Name()) + assert.Equal(t, SeverityWarning, v.Severity) + + if v.Location.Index != nil { + indexNames[*v.Location.Index] = true + } + } + + assert.True(t, indexNames["idx_email"]) + assert.True(t, indexNames["idx_name"]) +} + +func TestInvisibleIndexBeforeDropLinter_NonAlterStatement(t *testing.T) { + sql := "CREATE TABLE users (id INT PRIMARY KEY)" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint(nil, stmts) + + // Should not have violations for non-ALTER statements + assert.Empty(t, violations) +} + +func TestInvisibleIndexBeforeDropLinter_AlterWithoutDrop(t *testing.T) { + sql := "ALTER TABLE users ADD COLUMN age INT" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint(nil, stmts) + + // Should not have violations for ALTER without DROP INDEX + assert.Empty(t, violations) +} + +func TestInvisibleIndexBeforeDropLinter_Integration(t *testing.T) { + // Reset registry and register linter + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + violations, err := RunLinters(nil, stmts, Config{}) + require.NoError(t, err) + + require.Len(t, violations, 1) + assert.Equal(t, "invisible_index_before_drop", violations[0].Linter.Name()) +} + +func TestInvisibleIndexBeforeDropLinter_IntegrationDisabled(t *testing.T) { + // Reset registry and register linter + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + // Disable the linter + violations, err := RunLinters(nil, stmts, Config{ + Enabled: map[string]bool{ + "invisible_index_before_drop": false, + }, + }) + require.NoError(t, err) + + // Should not have violations when disabled + assert.Empty(t, violations) +} + +func TestInvisibleIndexBeforeDropLinter_Metadata(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + assert.Equal(t, "invisible_index_before_drop", linter.Name()) + assert.NotEmpty(t, linter.Description()) +} + +// Configuration Tests + +func TestInvisibleIndexBeforeDropLinter_Configure_ValidRaiseErrorTrue(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + config := map[string]string{ + "raiseError": "true", + } + + err := linter.Configure(config) + require.NoError(t, err) + assert.True(t, linter.raiseError) +} + +func TestInvisibleIndexBeforeDropLinter_Configure_ValidRaiseErrorFalse(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + config := map[string]string{ + "raiseError": "false", + } + + err := linter.Configure(config) + require.NoError(t, err) + assert.False(t, linter.raiseError) +} + +func TestInvisibleIndexBeforeDropLinter_Configure_CaseInsensitive(t *testing.T) { + tests := []struct { + name string + value string + expected bool + }{ + {"lowercase true", "true", true}, + {"uppercase TRUE", "TRUE", true}, + {"mixed True", "True", true}, + {"lowercase false", "false", false}, + {"uppercase FALSE", "FALSE", false}, + {"mixed False", "False", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + config := map[string]string{ + "raiseError": tt.value, + } + + err := linter.Configure(config) + require.NoError(t, err) + assert.Equal(t, tt.expected, linter.raiseError) + }) + } +} + +func TestInvisibleIndexBeforeDropLinter_Configure_InvalidRaiseErrorValue(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + config := map[string]string{ + "raiseError": "invalid", + } + + err := linter.Configure(config) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid value for raiseError") +} + +func TestInvisibleIndexBeforeDropLinter_Configure_UnknownKey(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + config := map[string]string{ + "unknownKey": "value", + } + + err := linter.Configure(config) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown config key") + assert.Contains(t, err.Error(), "unknownKey") +} + +func TestInvisibleIndexBeforeDropLinter_Configure_MultipleKeys(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + config := map[string]string{ + "raiseError": "false", + "unknownKey": "value", + } + + err := linter.Configure(config) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown config key") +} + +func TestInvisibleIndexBeforeDropLinter_DefaultConfig(t *testing.T) { + linter := &InvisibleIndexBeforeDropLinter{} + + defaultConfig := linter.DefaultConfig() + require.NotNil(t, defaultConfig) + assert.Equal(t, "false", defaultConfig["raiseError"]) +} + +// Functional Tests with Configuration + +func TestInvisibleIndexBeforeDropLinter_RaiseErrorTrue_ProducesError(t *testing.T) { + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + err = linter.Configure(map[string]string{"raiseError": "true"}) + require.NoError(t, err) + + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + assert.Equal(t, SeverityError, violations[0].Severity) +} + +func TestInvisibleIndexBeforeDropLinter_RaiseErrorFalse_ProducesWarning(t *testing.T) { + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + err = linter.Configure(map[string]string{"raiseError": "false"}) + require.NoError(t, err) + + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + assert.Equal(t, SeverityWarning, violations[0].Severity) +} + +func TestInvisibleIndexBeforeDropLinter_DefaultBehavior(t *testing.T) { + // Without configuration, default should be raiseError=true (warning) + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 1) + + linter := &InvisibleIndexBeforeDropLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + // Default behavior is warning (raiseError defaults to false in struct) + assert.Equal(t, SeverityWarning, violations[0].Severity) +} + +// Integration Tests with RunLinters + +func TestInvisibleIndexBeforeDropLinter_IntegrationWithConfig_RaiseErrorTrue(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + violations, err := RunLinters(nil, stmts, Config{ + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "true", + }, + }, + }) + require.NoError(t, err) + + require.Len(t, violations, 1) + assert.Equal(t, SeverityError, violations[0].Severity) +} + +func TestInvisibleIndexBeforeDropLinter_IntegrationWithConfig_RaiseErrorFalse(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + violations, err := RunLinters(nil, stmts, Config{ + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "false", + }, + }, + }) + require.NoError(t, err) + + require.Len(t, violations, 1) + assert.Equal(t, SeverityWarning, violations[0].Severity) +} + +func TestInvisibleIndexBeforeDropLinter_IntegrationWithConfig_InvalidConfig(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + // Invalid configuration should result in error + violations, err := RunLinters(nil, stmts, Config{ + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "invalid_value", + }, + }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid value for raiseError") + // Linter should be skipped due to configuration error + assert.Empty(t, violations) +} + +func TestInvisibleIndexBeforeDropLinter_IntegrationWithConfig_DisabledLinter(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + // Even with configuration, disabled linter should not run + violations, err := RunLinters(nil, stmts, Config{ + Enabled: map[string]bool{ + "invisible_index_before_drop": false, + }, + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "true", + }, + }, + }) + require.NoError(t, err) + assert.Empty(t, violations) +} + +func TestInvisibleIndexBeforeDropLinter_IntegrationWithConfig_EnabledWithConfig(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + // Explicitly enabled with custom configuration + violations, err := RunLinters(nil, stmts, Config{ + Enabled: map[string]bool{ + "invisible_index_before_drop": true, + }, + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "false", + }, + }, + }) + require.NoError(t, err) + + require.Len(t, violations, 1) + assert.Equal(t, SeverityWarning, violations[0].Severity) +} diff --git a/pkg/lint/lint_multiple_alter.go b/pkg/lint/lint_multiple_alter.go new file mode 100644 index 00000000..861fb4d2 --- /dev/null +++ b/pkg/lint/lint_multiple_alter.go @@ -0,0 +1,97 @@ +package lint + +import ( + "fmt" + "strings" + + "github.com/block/spirit/pkg/statement" +) + +func init() { + Register(&MultipleAlterTableLinter{}) +} + +// MultipleAlterTableLinter checks for multiple ALTER TABLE statements affecting the same table. +// Multiple ALTER TABLE statements on the same table should be combined into a single statement +// for better performance, fewer table rebuilds, and decreased danger of bad intermediate state. +type MultipleAlterTableLinter struct{} + +func (l *MultipleAlterTableLinter) String() string { + return Stringer(l) +} + +func (l *MultipleAlterTableLinter) Name() string { + return "multiple_alter_table" +} + +func (l *MultipleAlterTableLinter) Description() string { + return "Detects multiple ALTER TABLE statements on the same table that could be combined" +} + +func (l *MultipleAlterTableLinter) Lint(_ []*statement.CreateTable, statements []*statement.AbstractStatement) []Violation { + var violations []Violation + + // Count ALTER TABLE statements per table + tableAlterCounts := make(map[string][]int) // table name -> statement indices + + for i, stmt := range statements { + if !stmt.IsAlterTable() { + continue + } + + tableName := stmt.Table + if tableName == "" { + continue + } + + tableAlterCounts[tableName] = append(tableAlterCounts[tableName], i) + } + + // Report violations for tables with multiple ALTER statements + for tableName, indices := range tableAlterCounts { + if len(indices) < 2 { + continue + } + + // Build a list of the ALTER operations for the suggestion + var operations []string + + for _, idx := range indices { + if statements[idx].Alter != "" { + operations = append(operations, statements[idx].Alter) + } + } + + suggestion := "" + if len(operations) > 0 { + suggestion = fmt.Sprintf("Combine into: ALTER TABLE %s %s", + tableName, + strings.Join(operations, ", ")) + } + + message := fmt.Sprintf("Table '%s' has %d separate ALTER TABLE statements that could be combined into one for better performance", + tableName, + len(indices)) + + violation := Violation{ + Linter: l, + Severity: SeverityInfo, + Message: message, + Location: &Location{ + Table: tableName, + }, + Context: map[string]any{ + "alter_count": len(indices), + "statement_indices": indices, + }, + } + + if suggestion != "" { + violation.Suggestion = &suggestion + } + + violations = append(violations, violation) + } + + return violations +} diff --git a/pkg/lint/lint_multiple_alter_test.go b/pkg/lint/lint_multiple_alter_test.go new file mode 100644 index 00000000..4fd46c94 --- /dev/null +++ b/pkg/lint/lint_multiple_alter_test.go @@ -0,0 +1,271 @@ +package lint + +import ( + "testing" + + "github.com/block/spirit/pkg/statement" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMultipleAlterTableLinter_SingleAlter(t *testing.T) { + sql := "ALTER TABLE users ADD COLUMN age INT" + stmts, err := statement.New(sql) + require.NoError(t, err) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + // No violation for single ALTER + assert.Empty(t, violations) +} + +func TestMultipleAlterTableLinter_TwoAltersOnSameTable(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD INDEX idx_age (age)` + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 2) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + assert.Equal(t, "multiple_alter_table", violations[0].Linter.Name()) + assert.Equal(t, SeverityInfo, violations[0].Severity) + assert.Contains(t, violations[0].Message, "2 separate ALTER TABLE statements") + assert.Equal(t, "users", violations[0].Location.Table) + assert.NotNil(t, violations[0].Suggestion) + assert.Contains(t, *violations[0].Suggestion, "Combine into") + assert.Contains(t, *violations[0].Suggestion, "ADD COLUMN `age` INT") + assert.Contains(t, *violations[0].Suggestion, "ADD INDEX `idx_age`(`age`)") +} + +func TestMultipleAlterTableLinter_ThreeAltersOnSameTable(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD COLUMN email VARCHAR(255); + ALTER TABLE users ADD INDEX idx_email (email)` + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 3) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + assert.Equal(t, "multiple_alter_table", violations[0].Linter.Name()) + assert.Contains(t, violations[0].Message, "3 separate ALTER TABLE statements") + + // Check context + assert.NotNil(t, violations[0].Context) + assert.Equal(t, 3, violations[0].Context["alter_count"]) + indices, ok := violations[0].Context["statement_indices"].([]int) + require.True(t, ok) + assert.Len(t, indices, 3) +} + +func TestMultipleAlterTableLinter_AltersOnDifferentTables(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE orders ADD COLUMN total DECIMAL(10,2)` + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 2) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + // No violation when altering different tables + assert.Empty(t, violations) +} + +func TestMultipleAlterTableLinter_MixedStatements(t *testing.T) { + // Parse statements individually since statement.New() doesn't support mixed types + var stmts []*statement.AbstractStatement + + sql1 := "CREATE TABLE products (id INT PRIMARY KEY)" + s1, err := statement.New(sql1) + require.NoError(t, err) + + stmts = append(stmts, s1...) + + sql2 := "ALTER TABLE users ADD COLUMN age INT" + s2, err := statement.New(sql2) + require.NoError(t, err) + + stmts = append(stmts, s2...) + + sql3 := "ALTER TABLE users ADD INDEX idx_age (age)" + s3, err := statement.New(sql3) + require.NoError(t, err) + + stmts = append(stmts, s3...) + + sql4 := "DROP TABLE old_table" + s4, err := statement.New(sql4) + require.NoError(t, err) + + stmts = append(stmts, s4...) + + require.Len(t, stmts, 4) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + // Should only detect the two ALTER TABLE users statements + require.Len(t, violations, 1) + assert.Equal(t, "users", violations[0].Location.Table) + assert.Contains(t, violations[0].Message, "2 separate ALTER TABLE statements") +} + +func TestMultipleAlterTableLinter_MultipleTables(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD INDEX idx_age (age); + ALTER TABLE orders ADD COLUMN status VARCHAR(50); + ALTER TABLE orders ADD INDEX idx_status (status)` + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 4) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + // Should detect violations for both tables + require.Len(t, violations, 2) + + tableNames := make(map[string]bool) + for _, v := range violations { + tableNames[v.Location.Table] = true + assert.Equal(t, "multiple_alter_table", v.Linter.Name()) + assert.Contains(t, v.Message, "2 separate ALTER TABLE statements") + } + + assert.True(t, tableNames["users"]) + assert.True(t, tableNames["orders"]) +} + +func TestMultipleAlterTableLinter_ComplexAlters(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT, ADD COLUMN email VARCHAR(255); + ALTER TABLE users DROP COLUMN old_field` + stmts, err := statement.New(sql) + require.NoError(t, err) + require.Len(t, stmts, 2) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + assert.Equal(t, "users", violations[0].Location.Table) + + // Suggestion should include both ALTER operations (with backticks as generated by parser) + assert.NotNil(t, violations[0].Suggestion) + suggestion := *violations[0].Suggestion + assert.Contains(t, suggestion, "ADD COLUMN `age` INT, ADD COLUMN `email` VARCHAR(255)") + assert.Contains(t, suggestion, "DROP COLUMN `old_field`") +} + +func TestMultipleAlterTableLinter_NonAlterStatements(t *testing.T) { + // Parse statements individually since statement.New() doesn't support mixed types + var stmts []*statement.AbstractStatement + + sql1 := "CREATE TABLE users (id INT PRIMARY KEY)" + s1, err := statement.New(sql1) + require.NoError(t, err) + + stmts = append(stmts, s1...) + + sql2 := "CREATE TABLE orders (id INT PRIMARY KEY)" + s2, err := statement.New(sql2) + require.NoError(t, err) + + stmts = append(stmts, s2...) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + // No violations for non-ALTER statements + assert.Empty(t, violations) +} + +func TestMultipleAlterTableLinter_EmptyStatements(t *testing.T) { + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, nil) + + assert.Empty(t, violations) +} + +func TestMultipleAlterTableLinter_Integration(t *testing.T) { + Reset() + Register(&MultipleAlterTableLinter{}) + + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD INDEX idx_age (age)` + stmts, err := statement.New(sql) + require.NoError(t, err) + + violations, err := RunLinters(nil, stmts, Config{}) + require.NoError(t, err) + + require.Len(t, violations, 1) + assert.Equal(t, "multiple_alter_table", violations[0].Linter.Name()) +} + +func TestMultipleAlterTableLinter_IntegrationDisabled(t *testing.T) { + Reset() + Register(&MultipleAlterTableLinter{}) + + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD INDEX idx_age (age)` + stmts, err := statement.New(sql) + require.NoError(t, err) + + violations, err := RunLinters(nil, stmts, Config{ + Enabled: map[string]bool{ + "multiple_alter_table": false, + }, + }) + require.NoError(t, err) + + assert.Empty(t, violations) +} + +func TestMultipleAlterTableLinter_Metadata(t *testing.T) { + linter := &MultipleAlterTableLinter{} + + assert.Equal(t, "multiple_alter_table", linter.Name()) + assert.NotEmpty(t, linter.Description()) +} + +func TestMultipleAlterTableLinter_ContextData(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD COLUMN email VARCHAR(255); + ALTER TABLE users ADD INDEX idx_email (email)` + stmts, err := statement.New(sql) + require.NoError(t, err) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + + // Verify context contains useful debugging info + assert.NotNil(t, violations[0].Context) + assert.Equal(t, 3, violations[0].Context["alter_count"]) + + indices, ok := violations[0].Context["statement_indices"].([]int) + require.True(t, ok) + assert.Equal(t, []int{0, 1, 2}, indices) +} + +func TestMultipleAlterTableLinter_SeverityIsInfo(t *testing.T) { + sql := `ALTER TABLE users ADD COLUMN age INT; + ALTER TABLE users ADD INDEX idx_age (age)` + stmts, err := statement.New(sql) + require.NoError(t, err) + + linter := &MultipleAlterTableLinter{} + violations := linter.Lint(nil, stmts) + + require.Len(t, violations, 1) + // This is INFO level because it's an optimization suggestion, not an error + assert.Equal(t, SeverityInfo, violations[0].Severity) +} diff --git a/pkg/lint/lint_primary_key_type.go b/pkg/lint/lint_primary_key_type.go new file mode 100644 index 00000000..8c38300c --- /dev/null +++ b/pkg/lint/lint_primary_key_type.go @@ -0,0 +1,147 @@ +package lint + +import ( + "fmt" + "strings" + + "github.com/block/spirit/pkg/statement" + "github.com/pingcap/tidb/pkg/parser/mysql" +) + +func init() { + Register(&PrimaryKeyTypeLinter{}) +} + +// PrimaryKeyTypeLinter checks that primary keys use appropriate data types. +// Primary keys should be BIGINT (preferably UNSIGNED) or BINARY/VARBINARY. +// Other types are flagged as errors, and signed BIGINT is flagged as a warning. +type PrimaryKeyTypeLinter struct{} + +func (l *PrimaryKeyTypeLinter) String() string { + return Stringer(l) +} + +func (l *PrimaryKeyTypeLinter) Name() string { + return "primary_key_type" +} + +func (l *PrimaryKeyTypeLinter) Description() string { + return "Ensures primary keys use BIGINT (preferably UNSIGNED) or BINARY/VARBINARY types" +} + +func (l *PrimaryKeyTypeLinter) Lint(createTables []*statement.CreateTable, _ []*statement.AbstractStatement) []Violation { + var violations []Violation + + for _, ct := range createTables { + tableName := ct.GetTableName() + + // Get primary key columns from indexes (this includes both table-level and column-level PRIMARY KEY) + pkColumns := l.getPrimaryKeyColumnsFromIndexes(ct) + if len(pkColumns) == 0 { + continue + } + + // Check each primary key column's type + for _, pkCol := range pkColumns { + column := ct.GetColumns().ByName(pkCol) + if column == nil { + continue + } + + violation := l.checkColumnType(tableName, column) + if violation != nil { + violations = append(violations, *violation) + } + } + } + + return violations +} + +// getPrimaryKeyColumnsFromIndexes returns the names of columns that are part of the primary key +// by checking the indexes (which includes both table-level and column-level PRIMARY KEY definitions) +func (l *PrimaryKeyTypeLinter) getPrimaryKeyColumnsFromIndexes(ct *statement.CreateTable) []string { + var pkColumns []string + + // Check for PRIMARY KEY in indexes + for _, index := range ct.GetIndexes() { + if index.Type == "PRIMARY KEY" { + pkColumns = append(pkColumns, index.Columns...) + break // There can only be one PRIMARY KEY + } + } + + return pkColumns +} + +// checkColumnType checks if a primary key column has an appropriate type +func (l *PrimaryKeyTypeLinter) checkColumnType(tableName string, column *statement.Column) *Violation { + columnType := strings.ToUpper(column.Type) + + // Check for BIGINT + if strings.HasPrefix(columnType, "BIGINT") { + // Check if it's unsigned (either in type string or Unsigned field) + isUnsigned := column.Unsigned != nil && *column.Unsigned + if !isUnsigned { + isUnsigned = strings.Contains(columnType, "UNSIGNED") + } + + if isUnsigned { + // BIGINT UNSIGNED is ideal - no violation + return nil + } + + // BIGINT without UNSIGNED is a warning + suggestion := fmt.Sprintf("Consider using BIGINT UNSIGNED for column '%s' to avoid negative values and increase range", column.Name) + + return &Violation{ + Linter: l, + Severity: SeverityWarning, + Message: fmt.Sprintf("Primary key column '%s' uses signed BIGINT; UNSIGNED is preferred", column.Name), + Location: &Location{ + Table: tableName, + Column: &column.Name, + }, + Suggestion: &suggestion, + } + } + + // Check for BINARY/VARBINARY + // Note: The parser returns "char" for BINARY and "varchar" for VARBINARY, + // so we need to check the raw type and binary flag + if l.isBinaryType(column) { + // BINARY/VARBINARY is acceptable - no violation + return nil + } + + // Any other type is an error + suggestion := fmt.Sprintf("Change column '%s' to BIGINT UNSIGNED or BINARY/VARBINARY", column.Name) + + return &Violation{ + Linter: l, + Severity: SeverityError, + Message: fmt.Sprintf("Primary key column '%s' has type '%s'; must be BIGINT or BINARY/VARBINARY", column.Name, column.Type), + Location: &Location{ + Table: tableName, + Column: &column.Name, + }, + Suggestion: &suggestion, + Context: map[string]interface{}{ + "current_type": column.Type, + }, + } +} + +// isBinaryType checks if a column is BINARY or VARBINARY type +// The parser returns "char" for BINARY and "varchar" for VARBINARY, so we need to check the binary flag +func (l *PrimaryKeyTypeLinter) isBinaryType(column *statement.Column) bool { + if column.Raw == nil || column.Raw.Tp == nil { + return false + } + + // BINARY is mysql.TypeString with binary flag + // VARBINARY is mysql.TypeVarchar with binary flag + rawType := column.Raw.Tp.GetType() + + return (rawType == mysql.TypeString || rawType == mysql.TypeVarchar) && mysql.HasBinaryFlag(column.Raw.Tp.GetFlag()) +} diff --git a/pkg/lint/lint_primary_key_type_test.go b/pkg/lint/lint_primary_key_type_test.go new file mode 100644 index 00000000..1ee93b36 --- /dev/null +++ b/pkg/lint/lint_primary_key_type_test.go @@ -0,0 +1,396 @@ +package lint + +import ( + "testing" + + "github.com/block/spirit/pkg/statement" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPrimaryKeyTypeLinter_BigIntUnsigned(t *testing.T) { + sql := `CREATE TABLE users ( + id BIGINT UNSIGNED PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // BIGINT UNSIGNED is ideal - no violations + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_BigIntSigned(t *testing.T) { + sql := `CREATE TABLE users ( + id BIGINT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // BIGINT without UNSIGNED should be a warning + require.Len(t, violations, 1) + assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) + assert.Equal(t, SeverityWarning, violations[0].Severity) + assert.Contains(t, violations[0].Message, "signed BIGINT") + assert.Contains(t, violations[0].Message, "UNSIGNED is preferred") + assert.Equal(t, "users", violations[0].Location.Table) + assert.NotNil(t, violations[0].Location.Column) + assert.Equal(t, "id", *violations[0].Location.Column) + assert.NotNil(t, violations[0].Suggestion) + assert.Contains(t, *violations[0].Suggestion, "BIGINT UNSIGNED") +} + +func TestPrimaryKeyTypeLinter_Binary(t *testing.T) { + sql := `CREATE TABLE users ( + id BINARY(16) PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // BINARY is acceptable - no violations + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_VarBinary(t *testing.T) { + sql := `CREATE TABLE users ( + id VARBINARY(255) PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // VARBINARY is acceptable - no violations + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_IntError(t *testing.T) { + sql := `CREATE TABLE users ( + id INT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // INT is not acceptable - should be an error + require.Len(t, violations, 1) + assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) + assert.Equal(t, SeverityError, violations[0].Severity) + assert.Contains(t, violations[0].Message, "must be BIGINT or BINARY/VARBINARY") + assert.Equal(t, "users", violations[0].Location.Table) + assert.NotNil(t, violations[0].Location.Column) + assert.Equal(t, "id", *violations[0].Location.Column) + assert.NotNil(t, violations[0].Suggestion) + assert.Contains(t, *violations[0].Suggestion, "BIGINT UNSIGNED") +} + +func TestPrimaryKeyTypeLinter_VarcharError(t *testing.T) { + sql := `CREATE TABLE users ( + id VARCHAR(36) PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // VARCHAR is not acceptable - should be an error + require.Len(t, violations, 1) + assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) + assert.Equal(t, SeverityError, violations[0].Severity) + assert.Contains(t, violations[0].Message, "must be BIGINT or BINARY/VARBINARY") + assert.NotNil(t, violations[0].Context) + // The parser returns lowercase "varchar" + assert.Equal(t, "varchar", violations[0].Context["current_type"]) +} + +func TestPrimaryKeyTypeLinter_CompositePrimaryKey(t *testing.T) { + sql := `CREATE TABLE user_roles ( + user_id BIGINT UNSIGNED, + role_id INT, + PRIMARY KEY (user_id, role_id) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // user_id is BIGINT UNSIGNED (good), role_id is INT (error) + require.Len(t, violations, 1) + assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) + assert.Equal(t, SeverityError, violations[0].Severity) + assert.Equal(t, "role_id", *violations[0].Location.Column) +} + +func TestPrimaryKeyTypeLinter_CompositePrimaryKeyAllGood(t *testing.T) { + sql := `CREATE TABLE user_roles ( + user_id BIGINT UNSIGNED, + role_id BIGINT UNSIGNED, + PRIMARY KEY (user_id, role_id) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // Both columns are BIGINT UNSIGNED - no violations + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_CompositePrimaryKeyMixed(t *testing.T) { + sql := `CREATE TABLE user_roles ( + user_id BIGINT, + role_id BIGINT UNSIGNED, + PRIMARY KEY (user_id, role_id) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // user_id is signed BIGINT (warning) + require.Len(t, violations, 1) + assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) + assert.Equal(t, SeverityWarning, violations[0].Severity) + assert.Equal(t, "user_id", *violations[0].Location.Column) +} + +func TestPrimaryKeyTypeLinter_NoPrimaryKey(t *testing.T) { + sql := `CREATE TABLE logs ( + id BIGINT UNSIGNED, + message TEXT + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // No primary key - no violations from this linter + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_MultipleTables(t *testing.T) { + sql1 := `CREATE TABLE users ( + id BIGINT UNSIGNED PRIMARY KEY, + name VARCHAR(255) + )` + ct1, err := statement.ParseCreateTable(sql1) + require.NoError(t, err) + + sql2 := `CREATE TABLE orders ( + id INT PRIMARY KEY, + user_id BIGINT UNSIGNED + )` + ct2, err := statement.ParseCreateTable(sql2) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct1, ct2}, nil) + + // Only orders table should have a violation + require.Len(t, violations, 1) + assert.Equal(t, "orders", violations[0].Location.Table) + assert.Equal(t, SeverityError, violations[0].Severity) +} + +func TestPrimaryKeyTypeLinter_SmallIntError(t *testing.T) { + sql := `CREATE TABLE users ( + id SMALLINT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // SMALLINT is not acceptable - should be an error + require.Len(t, violations, 1) + assert.Equal(t, SeverityError, violations[0].Severity) +} + +func TestPrimaryKeyTypeLinter_MediumIntError(t *testing.T) { + sql := `CREATE TABLE users ( + id MEDIUMINT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // MEDIUMINT is not acceptable - should be an error + require.Len(t, violations, 1) + assert.Equal(t, SeverityError, violations[0].Severity) +} + +func TestPrimaryKeyTypeLinter_CharError(t *testing.T) { + sql := `CREATE TABLE users ( + id CHAR(36) PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // CHAR is not acceptable - should be an error + require.Len(t, violations, 1) + assert.Equal(t, SeverityError, violations[0].Severity) +} + +func TestPrimaryKeyTypeLinter_EmptyInput(t *testing.T) { + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint(nil, nil) + + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_Integration(t *testing.T) { + Reset() + Register(&PrimaryKeyTypeLinter{}) + + sql := `CREATE TABLE users ( + id INT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + violations, err := RunLinters([]*statement.CreateTable{ct}, nil, Config{}) + require.NoError(t, err) + + require.Len(t, violations, 1) + assert.Equal(t, "primary_key_type", violations[0].Linter.Name()) +} + +func TestPrimaryKeyTypeLinter_IntegrationDisabled(t *testing.T) { + Reset() + Register(&PrimaryKeyTypeLinter{}) + + sql := `CREATE TABLE users ( + id INT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + violations, err := RunLinters([]*statement.CreateTable{ct}, nil, Config{ + Enabled: map[string]bool{ + "primary_key_type": false, + }, + }) + require.NoError(t, err) + + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_Metadata(t *testing.T) { + linter := &PrimaryKeyTypeLinter{} + + assert.Equal(t, "primary_key_type", linter.Name()) + assert.NotEmpty(t, linter.Description()) +} + +func TestPrimaryKeyTypeLinter_BigIntUnsignedExplicit(t *testing.T) { + sql := `CREATE TABLE users ( + id BIGINT(20) UNSIGNED PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // BIGINT(20) UNSIGNED is ideal - no violations + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_CaseInsensitive(t *testing.T) { + // Test that type checking is case-insensitive + sql := `CREATE TABLE users ( + id bigint unsigned PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // Should recognize lowercase bigint unsigned + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_AutoIncrement(t *testing.T) { + sql := `CREATE TABLE users ( + id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // BIGINT UNSIGNED with AUTO_INCREMENT is ideal - no violations + assert.Empty(t, violations) +} + +func TestPrimaryKeyTypeLinter_UUIDAsVarchar(t *testing.T) { + // Common anti-pattern: using VARCHAR for UUIDs + sql := `CREATE TABLE users ( + uuid VARCHAR(36) PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // VARCHAR for UUID should be an error + require.Len(t, violations, 1) + assert.Equal(t, SeverityError, violations[0].Severity) + assert.Contains(t, *violations[0].Suggestion, "BINARY/VARBINARY") +} + +func TestPrimaryKeyTypeLinter_UUIDAsBinary(t *testing.T) { + // Correct way to store UUIDs + sql := `CREATE TABLE users ( + uuid BINARY(16) PRIMARY KEY, + name VARCHAR(255) + )` + ct, err := statement.ParseCreateTable(sql) + require.NoError(t, err) + + linter := &PrimaryKeyTypeLinter{} + violations := linter.Lint([]*statement.CreateTable{ct}, nil) + + // BINARY(16) for UUID is correct - no violations + assert.Empty(t, violations) +} diff --git a/pkg/lint/lint_test.go b/pkg/lint/lint_test.go new file mode 100644 index 00000000..04c5f86f --- /dev/null +++ b/pkg/lint/lint_test.go @@ -0,0 +1,545 @@ +package lint + +import ( + "testing" + + "github.com/block/spirit/pkg/statement" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Mock linter for testing +type mockLinter struct { + name string + description string + violations []Violation +} + +func (m *mockLinter) String() string { + //TODO implement me + panic("implement me") +} + +func (m *mockLinter) Name() string { return m.name } +func (m *mockLinter) Description() string { return m.description } +func (m *mockLinter) Lint(createTables []*statement.CreateTable, statements []*statement.AbstractStatement) []Violation { + return m.violations +} + +// Configurable mock linter for testing +type mockConfigurableLinter struct { + mockLinter + + configCalled bool + configValue map[string]string +} + +func (m *mockConfigurableLinter) String() string { + //TODO implement me + panic("implement me") +} + +func (m *mockConfigurableLinter) Configure(config map[string]string) error { + m.configCalled = true + m.configValue = config + + return nil +} + +func (m *mockConfigurableLinter) DefaultConfig() map[string]string { + return map[string]string{ + "default": "value", + } +} + +func TestRegister(t *testing.T) { + // Reset registry before test + Reset() + + linter := &mockLinter{ + name: "test_linter", + description: "A test linter", + } + + Register(linter) + + // Verify linter was registered + names := List() + assert.Contains(t, names, "test_linter") + + // Verify we can get it back + retrieved, err := Get("test_linter") + require.NoError(t, err) + assert.Equal(t, "test_linter", retrieved.Name()) +} + +func TestRegisterMultiple(t *testing.T) { + Reset() + + linter1 := &mockLinter{name: "linter1"} + linter2 := &mockLinter{name: "linter2"} + linter3 := &mockLinter{name: "linter3"} + + Register(linter1) + Register(linter2) + Register(linter3) + + names := List() + assert.Len(t, names, 3) + assert.Contains(t, names, "linter1") + assert.Contains(t, names, "linter2") + assert.Contains(t, names, "linter3") +} + +func TestEnableDisable(t *testing.T) { + Reset() + + linter := &mockLinter{name: "test_linter"} + Register(linter) + + // Linters are enabled by default + assert.True(t, linters["test_linter"].enabled) + + // Disable it + err := Disable("test_linter") + require.NoError(t, err) + assert.False(t, linters["test_linter"].enabled) + + // Enable it again + err = Enable("test_linter") + require.NoError(t, err) + assert.True(t, linters["test_linter"].enabled) +} + +func TestEnableDisableNonexistent(t *testing.T) { + Reset() + + err := Enable("nonexistent") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + + err = Disable("nonexistent") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestGet(t *testing.T) { + Reset() + + linter := &mockLinter{ + name: "test_linter", + description: "A test linter", + } + Register(linter) + + retrieved, err := Get("test_linter") + require.NoError(t, err) + assert.Equal(t, "test_linter", retrieved.Name()) + assert.Equal(t, "A test linter", retrieved.Description()) +} + +func TestGetNonexistent(t *testing.T) { + Reset() + + _, err := Get("nonexistent") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestRunLinters_Empty(t *testing.T) { + Reset() + + violations, err := RunLinters(nil, nil, Config{}) + require.NoError(t, err) + assert.Empty(t, violations) +} + +func TestRunLinters_SingleLinter(t *testing.T) { + Reset() + + linter := &mockLinter{ + name: "test_linter", + } + + expectedViolations := []Violation{ + { + Linter: linter, + Severity: SeverityError, + Message: "Test error", + }, + } + linter.violations = expectedViolations + + Register(linter) + + violations, err := RunLinters(nil, nil, Config{}) + require.NoError(t, err) + assert.Len(t, violations, 1) + assert.Equal(t, "test_linter", violations[0].Linter.Name()) + assert.Equal(t, SeverityError, violations[0].Severity) + assert.Equal(t, "Test error", violations[0].Message) +} + +func TestRunLinters_MultipleLinters(t *testing.T) { + Reset() + + linter1 := &mockLinter{ + name: "linter1", + } + linter1.violations = []Violation{ + {Linter: linter1, Severity: SeverityError, Message: "Error 1"}, + } + + linter2 := &mockLinter{ + name: "linter2", + } + linter2.violations = []Violation{ + {Linter: linter2, Severity: SeverityWarning, Message: "Warning 1"}, + {Linter: linter2, Severity: SeverityInfo, Message: "Info 1"}, + } + + Register(linter1) + Register(linter2) + + violations, err := RunLinters(nil, nil, Config{}) + require.NoError(t, err) + assert.Len(t, violations, 3) +} + +func TestRunLinters_WithConfig_Disabled(t *testing.T) { + Reset() + + linter := &mockLinter{ + name: "test_linter", + } + linter.violations = []Violation{ + {Linter: linter, Severity: SeverityError, Message: "Should not see this"}, + } + Register(linter) + + // Disable the linter via config + violations, err := RunLinters(nil, nil, Config{ + Enabled: map[string]bool{ + "test_linter": false, + }, + }) + require.NoError(t, err) + + assert.Empty(t, violations) +} + +func TestRunLinters_WithConfig_Enabled(t *testing.T) { + Reset() + + linter := &mockLinter{ + name: "test_linter", + } + linter.violations = []Violation{ + {Linter: linter, Severity: SeverityError, Message: "Should see this"}, + } + + // Disable by default + Register(linter) + require.NoError(t, Disable("test_linter")) + + // But explicitly enable via config + violations, err := RunLinters(nil, nil, Config{ + Enabled: map[string]bool{ + "test_linter": true, + }, + }) + require.NoError(t, err) + + assert.Len(t, violations, 1) + assert.Equal(t, "Should see this", violations[0].Message) +} + +func TestRunLinters_ConfigurableLinter(t *testing.T) { + Reset() + + linter := &mockConfigurableLinter{} + linter.name = "configurable_linter" + linter.violations = []Violation{ + {Linter: linter, Severity: SeverityError, Message: "Test"}, + } + Register(linter) + + config := map[string]string{"key": "value"} + violations, err := RunLinters(nil, nil, Config{ + Settings: map[string]map[string]string{ + "configurable_linter": config, + }, + }) + require.NoError(t, err) + + assert.Len(t, violations, 1) + assert.True(t, linter.configCalled) + // User config should be merged with defaults + expected := map[string]string{ + "default": "value", // from DefaultConfig + "key": "value", // from user config + } + assert.Equal(t, expected, linter.configValue) +} + +func TestRunLinters_ConfigurableLinter_NoConfig(t *testing.T) { + Reset() + + linter := &mockConfigurableLinter{} + linter.name = "configurable_linter" + linter.violations = []Violation{ + {Linter: linter, Severity: SeverityError, Message: "Test"}, + } + Register(linter) + + violations, err := RunLinters(nil, nil, Config{}) + require.NoError(t, err) + + assert.Len(t, violations, 1) + // Now Configure is always called (with defaults) + assert.True(t, linter.configCalled) + // Should have received the default config + assert.Equal(t, map[string]string{"default": "value"}, linter.configValue) +} + +func TestHasErrors(t *testing.T) { + violations := []Violation{ + {Severity: SeverityWarning}, + {Severity: SeverityInfo}, + } + assert.False(t, HasErrors(violations)) + + violations = append(violations, Violation{Severity: SeverityError}) + assert.True(t, HasErrors(violations)) +} + +func TestHasWarnings(t *testing.T) { + violations := []Violation{ + {Severity: SeverityError}, + {Severity: SeverityInfo}, + } + assert.False(t, HasWarnings(violations)) + + violations = append(violations, Violation{Severity: SeverityWarning}) + assert.True(t, HasWarnings(violations)) +} + +func TestFilterBySeverity(t *testing.T) { + violations := []Violation{ + {Severity: SeverityError, Message: "Error 1"}, + {Severity: SeverityWarning, Message: "Warning 1"}, + {Severity: SeverityError, Message: "Error 2"}, + {Severity: SeverityInfo, Message: "Info 1"}, + } + + errors := FilterBySeverity(violations, SeverityError) + assert.Len(t, errors, 2) + assert.Equal(t, "Error 1", errors[0].Message) + assert.Equal(t, "Error 2", errors[1].Message) + + warnings := FilterBySeverity(violations, SeverityWarning) + assert.Len(t, warnings, 1) + assert.Equal(t, "Warning 1", warnings[0].Message) + + infos := FilterBySeverity(violations, SeverityInfo) + assert.Len(t, infos, 1) + assert.Equal(t, "Info 1", infos[0].Message) +} + +func TestFilterByLinter(t *testing.T) { + linter1 := &mockLinter{name: "linter1"} + linter2 := &mockLinter{name: "linter2"} + + violations := []Violation{ + {Linter: linter1, Message: "Message 1"}, + {Linter: linter2, Message: "Message 2"}, + {Linter: linter1, Message: "Message 3"}, + } + + linter1Violations := FilterByLinter(violations, "linter1") + assert.Len(t, linter1Violations, 2) + assert.Equal(t, "Message 1", linter1Violations[0].Message) + assert.Equal(t, "Message 3", linter1Violations[1].Message) + + linter2Violations := FilterByLinter(violations, "linter2") + assert.Len(t, linter2Violations, 1) + assert.Equal(t, "Message 2", linter2Violations[0].Message) + + nonexistentViolations := FilterByLinter(violations, "nonexistent") + assert.Empty(t, nonexistentViolations) +} + +func TestListSorted(t *testing.T) { + Reset() + + // Register in non-alphabetical order + Register(&mockLinter{name: "zebra"}) + Register(&mockLinter{name: "alpha"}) + Register(&mockLinter{name: "beta"}) + + names := List() + assert.Equal(t, []string{"alpha", "beta", "zebra"}, names) +} + +func TestReset(t *testing.T) { + Reset() + + Register(&mockLinter{name: "linter1"}) + Register(&mockLinter{name: "linter2"}) + + assert.Len(t, List(), 2) + + Reset() + + assert.Empty(t, List()) +} + +func TestViolationWithLocation(t *testing.T) { + column := "test_column" + index := "test_index" + constraint := "test_constraint" + linter := &mockLinter{name: "test_linter"} + + violation := Violation{ + Linter: linter, + Severity: SeverityError, + Message: "Test message", + Location: &Location{ + Table: "test_table", + Column: &column, + Index: &index, + Constraint: &constraint, + }, + } + + assert.Equal(t, "test_table", violation.Location.Table) + assert.Equal(t, "test_column", *violation.Location.Column) + assert.Equal(t, "test_index", *violation.Location.Index) + assert.Equal(t, "test_constraint", *violation.Location.Constraint) +} + +func TestViolationWithSuggestion(t *testing.T) { + suggestion := "Try this instead" + linter := &mockLinter{name: "test_linter"} + + violation := Violation{ + Linter: linter, + Severity: SeverityWarning, + Message: "Test message", + Suggestion: &suggestion, + } + + assert.NotNil(t, violation.Suggestion) + assert.Equal(t, "Try this instead", *violation.Suggestion) +} + +func TestViolationWithContext(t *testing.T) { + linter := &mockLinter{name: "test_linter"} + + violation := Violation{ + Linter: linter, + Severity: SeverityInfo, + Message: "Test message", + Context: map[string]any{ + "key1": "value1", + "key2": 42, + }, + } + + assert.Len(t, violation.Context, 2) + assert.Equal(t, "value1", violation.Context["key1"]) + assert.Equal(t, 42, violation.Context["key2"]) +} + +// ConfigBool tests + +func TestConfigBool_ValidTrue(t *testing.T) { + tests := []string{"true", "TRUE", "True", "TrUe"} + for _, value := range tests { + t.Run(value, func(t *testing.T) { + result, err := ConfigBool(value, "testKey") + require.NoError(t, err) + assert.True(t, result) + }) + } +} + +func TestConfigBool_ValidFalse(t *testing.T) { + tests := []string{"false", "FALSE", "False", "FaLsE"} + for _, value := range tests { + t.Run(value, func(t *testing.T) { + result, err := ConfigBool(value, "testKey") + require.NoError(t, err) + assert.False(t, result) + }) + } +} + +func TestConfigBool_Invalid(t *testing.T) { + tests := []struct { + value string + key string + }{ + {"yes", "testKey"}, + {"no", "testKey"}, + {"1", "testKey"}, + {"0", "testKey"}, + {"True ", "testKey"}, // trailing space + {" true", "testKey"}, // leading space + {"", "testKey"}, + {"invalid", "myOption"}, + } + + for _, tt := range tests { + t.Run(tt.value, func(t *testing.T) { + result, err := ConfigBool(tt.value, tt.key) + require.Error(t, err) + assert.False(t, result) + assert.Contains(t, err.Error(), "invalid value for "+tt.key) + assert.Contains(t, err.Error(), tt.value) + assert.Contains(t, err.Error(), "expected 'true' or 'false'") + }) + } +} + +// DefaultConfig tests + +func TestRunLinters_AppliesDefaultConfig(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + // Run without any config - should apply default (raiseError=false) + violations, err := RunLinters(nil, stmts, Config{}) + require.NoError(t, err) + + require.Len(t, violations, 1) + // Default raiseError is "false", so severity should be Warning + assert.Equal(t, SeverityWarning, violations[0].Severity) +} + +func TestRunLinters_UserConfigOverridesDefault(t *testing.T) { + Reset() + Register(&InvisibleIndexBeforeDropLinter{}) + + sql := "ALTER TABLE users DROP INDEX idx_email" + stmts, err := statement.New(sql) + require.NoError(t, err) + + // Override default raiseError=false with true + violations, err := RunLinters(nil, stmts, Config{ + Settings: map[string]map[string]string{ + "invisible_index_before_drop": { + "raiseError": "true", + }, + }, + }) + require.NoError(t, err) + + require.Len(t, violations, 1) + // User set raiseError=true, so severity should be Error + assert.Equal(t, SeverityError, violations[0].Severity) +} diff --git a/pkg/lint/linter.go b/pkg/lint/linter.go new file mode 100644 index 00000000..f24360a0 --- /dev/null +++ b/pkg/lint/linter.go @@ -0,0 +1,57 @@ +package lint + +import ( + "fmt" + "strings" + + "github.com/block/spirit/pkg/statement" +) + +// Linter is the interface that all linters must implement +type Linter interface { + // Name returns the unique name of this linter + Name() string + + // Description returns a human-readable description of what this linter checks + Description() string + + // Lint performs the actual linting and returns any violations found. + // Linters can use either or both of the parameters as needed. + Lint(createTables []*statement.CreateTable, alterStatements []*statement.AbstractStatement) []Violation + + // String returns a string representation of the linter + String() string +} + +// ConfigurableLinter is an optional interface for linters that support configuration +type ConfigurableLinter interface { + Linter + + // Configure applies configuration to the linter + // Configuration is provided as a map of string keys to string values + Configure(config map[string]string) error + + // DefaultConfig returns the default configuration for this linter + DefaultConfig() map[string]string +} + +// Stringer returns a string representation of the linter +// This is a helper function used by linters' String() methods. +func Stringer(l Linter) string { + return l.Name() + " - " + l.Description() +} + +// ConfigBool parses a boolean configuration value from a string. +// It accepts "true" or "false" (case-insensitive) and returns an error for invalid values. +// The key parameter is used in error messages to provide context. +func ConfigBool(value string, key string) (bool, error) { + if strings.EqualFold(value, "true") { + return true, nil + } + + if strings.EqualFold(value, "false") { + return false, nil + } + + return false, fmt.Errorf("invalid value for %s: %s (expected 'true' or 'false')", key, value) +} diff --git a/pkg/lint/registry.go b/pkg/lint/registry.go new file mode 100644 index 00000000..dffe8632 --- /dev/null +++ b/pkg/lint/registry.go @@ -0,0 +1,111 @@ +package lint + +import ( + "fmt" + "sort" + "sync" +) + +// linter represents a registered linter with metadata +type linter struct { + l Linter + enabled bool +} + +var ( + linters map[string]linter + lock sync.RWMutex +) + +// Register registers a linter with the global registry. +// This should be called from init() functions in linter implementations. +// Linters are enabled by default when registered. +func Register(l Linter) { + lock.Lock() + defer lock.Unlock() + + if linters == nil { + linters = make(map[string]linter) + } + + linters[l.Name()] = linter{ + l: l, + enabled: true, + } +} + +// Enable enables specific linters by name. +// Returns an error if the linter is not found. +func Enable(names ...string) error { + lock.Lock() + defer lock.Unlock() + + for _, name := range names { + l, ok := linters[name] + if !ok { + return fmt.Errorf("linter %q not found", name) + } + + l.enabled = true + linters[name] = l + } + + return nil +} + +// Disable disables specific linters by name. +// Returns an error if the linter is not found. +func Disable(names ...string) error { + lock.Lock() + defer lock.Unlock() + + for _, name := range names { + l, ok := linters[name] + if !ok { + return fmt.Errorf("linter %q not found", name) + } + + l.enabled = false + linters[name] = l + } + + return nil +} + +// List returns the names of all registered linters in sorted order. +func List() []string { + lock.RLock() + defer lock.RUnlock() + + names := make([]string, 0, len(linters)) + for name := range linters { + names = append(names, name) + } + + sort.Strings(names) + + return names +} + +// Get returns a linter by name. +// Returns an error if the linter is not found. +func Get(name string) (Linter, error) { + lock.RLock() + defer lock.RUnlock() + + l, ok := linters[name] + if !ok { + return nil, fmt.Errorf("linter %q not found", name) + } + + return l.l, nil +} + +// Reset clears all registered linters. +// This is primarily useful for testing. +func Reset() { + lock.Lock() + defer lock.Unlock() + + linters = make(map[string]linter) +} diff --git a/pkg/lint/violation.go b/pkg/lint/violation.go new file mode 100644 index 00000000..d9f725fc --- /dev/null +++ b/pkg/lint/violation.go @@ -0,0 +1,84 @@ +package lint + +import "fmt" + +// Severity represents the severity level of a linting violation +type Severity string + +const ( + // SeverityError indicates a violation that will cause actual problems + // (syntax errors, MySQL limitations, etc.) + SeverityError Severity = "ERROR" + + // SeverityWarning indicates a best practice violation or potential issue + SeverityWarning Severity = "WARNING" + + // SeverityInfo indicates a suggestion or style preference + SeverityInfo Severity = "INFO" +) + +// Violation represents a linting violation found during analysis +type Violation struct { + // Linter is the linter that produced this violation + Linter Linter + + // Severity is the severity level of the violation + Severity Severity + + // Message is a human-readable description of the violation + Message string + + // Location provides information about where the violation occurred + Location *Location + + // Suggestion is an optional suggestion for fixing the violation + Suggestion *string + + // Context provides additional context-specific information + Context map[string]any +} + +func (v Violation) String() string { + msg := fmt.Sprintf("[%s] %s: %s", v.Severity, v.Linter.Name(), v.Message) + if v.Location != nil { + msg += fmt.Sprintf(" (%s)", v.Location) + } + + if v.Suggestion != nil { + msg += " Suggestion: " + *v.Suggestion + } + + return msg +} + +// Location provides information about where a violation occurred +type Location struct { + // Table is the name of the table where the violation occurred + Table string + + // Column is the name of the column (if applicable) + Column *string + + // Index is the name of the index (if applicable) + Index *string + + // Constraint is the name of the constraint (if applicable) + Constraint *string +} + +func (l *Location) String() string { + msg := "Table: " + l.Table + if l.Column != nil { + msg += ", Column: " + *l.Column + } + + if l.Index != nil { + msg += ", Index: " + *l.Index + } + + if l.Constraint != nil { + msg += ", Constraint: " + *l.Constraint + } + + return msg +} diff --git a/pkg/statement/parse_create_table.go b/pkg/statement/parse_create_table.go index 83e6bcaf..2029cf46 100644 --- a/pkg/statement/parse_create_table.go +++ b/pkg/statement/parse_create_table.go @@ -171,7 +171,7 @@ func ParseCreateTable(sql string) (*CreateTable, error) { stmts, _, err := p.Parse(sql, "", "") if err != nil { - return nil, fmt.Errorf("failed to parse SQL: %w", err) + return nil, fmt.Errorf("failed to parse SQL %q: %w", sql, err) } if len(stmts) != 1 {