From 5f89eb297d442509e2294872d8d7c4bb62b06317 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sat, 29 Nov 2025 22:05:09 -0800 Subject: [PATCH 1/8] feat(postgresql): Add star expander for SELECT * and RETURNING * MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a new expander package that expands * expressions in SQL queries to explicit column names by preparing the query against a PostgreSQL database. Features: - Expands SELECT * to explicit column list - Preserves table prefix for qualified stars (e.g., table.*) - Handles RETURNING * in INSERT/UPDATE/DELETE statements - Recursively expands CTEs, including dependent CTEs - Supports subqueries in FROM clause - Works with both cgo (pganalyze/pg_query_go) and non-cgo (wasilibs/go-pgquery) builds Example: SELECT * FROM authors → SELECT id, name, bio FROM authors 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .../engine/postgresql/expander/expander.go | 599 ++++++++++++++++++ .../postgresql/expander/expander_test.go | 120 ++++ .../postgresql/expander/parse_default.go | 10 + .../engine/postgresql/expander/parse_wasi.go | 10 + 4 files changed, 739 insertions(+) create mode 100644 internal/engine/postgresql/expander/expander.go create mode 100644 internal/engine/postgresql/expander/expander_test.go create mode 100644 internal/engine/postgresql/expander/parse_default.go create mode 100644 internal/engine/postgresql/expander/parse_wasi.go diff --git a/internal/engine/postgresql/expander/expander.go b/internal/engine/postgresql/expander/expander.go new file mode 100644 index 0000000000..04129e5275 --- /dev/null +++ b/internal/engine/postgresql/expander/expander.go @@ -0,0 +1,599 @@ +package expander + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5/pgxpool" + nodes "github.com/pganalyze/pg_query_go/v6" +) + +// Expander expands SELECT * and RETURNING * queries by replacing * with explicit column names +// obtained from preparing the query against a PostgreSQL database. +type Expander struct { + pool *pgxpool.Pool +} + +// New creates a new Expander with the given connection pool. +func New(pool *pgxpool.Pool) *Expander { + return &Expander{pool: pool} +} + +// Expand takes a SQL query, and if it contains * in SELECT or RETURNING clause, +// expands it to use explicit column names. Returns the expanded query string. +func (e *Expander) Expand(ctx context.Context, query string) (string, error) { + // Parse the query + tree, err := parse(query) + if err != nil { + return "", fmt.Errorf("failed to parse query: %w", err) + } + + if len(tree.Stmts) == 0 { + return query, nil + } + + stmt := tree.Stmts[0].Stmt + + // Check if there's any star in the statement (including CTEs, subqueries, etc.) + if !hasStarAnywhere(stmt) { + return query, nil + } + + // Expand all stars in the statement recursively + if err := e.expandNode(ctx, stmt); err != nil { + return "", err + } + + // Deparse the modified AST back to SQL + expanded, err := deparse(tree) + if err != nil { + return "", fmt.Errorf("failed to deparse query: %w", err) + } + + return expanded, nil +} + +// expandNode recursively expands * in all parts of the statement +func (e *Expander) expandNode(ctx context.Context, node *nodes.Node) error { + if node == nil { + return nil + } + + switch n := node.Node.(type) { + case *nodes.Node_SelectStmt: + return e.expandSelectStmt(ctx, n.SelectStmt) + case *nodes.Node_InsertStmt: + return e.expandInsertStmt(ctx, n.InsertStmt) + case *nodes.Node_UpdateStmt: + return e.expandUpdateStmt(ctx, n.UpdateStmt) + case *nodes.Node_DeleteStmt: + return e.expandDeleteStmt(ctx, n.DeleteStmt) + case *nodes.Node_CommonTableExpr: + return e.expandNode(ctx, n.CommonTableExpr.Ctequery) + } + return nil +} + +// expandSelectStmt expands * in a SELECT statement including CTEs and subqueries +func (e *Expander) expandSelectStmt(ctx context.Context, stmt *nodes.SelectStmt) error { + // First expand any CTEs - must be done in order since later CTEs may depend on earlier ones + if stmt.WithClause != nil { + for _, cte := range stmt.WithClause.Ctes { + cteExpr, ok := cte.Node.(*nodes.Node_CommonTableExpr) + if !ok { + continue + } + cteSelect, ok := cteExpr.CommonTableExpr.Ctequery.Node.(*nodes.Node_SelectStmt) + if !ok { + continue + } + if hasStarInList(cteSelect.SelectStmt.TargetList) { + // Deparse the full statement (with WITH clause context) but query just this CTE + // We need to build a query that includes all prior CTEs for context + columns, err := e.getCTEColumnNames(ctx, stmt, cteExpr.CommonTableExpr) + if err != nil { + return err + } + cteSelect.SelectStmt.TargetList = rewriteTargetList(cteSelect.SelectStmt.TargetList, columns) + } + // Recursively handle nested CTEs/subqueries in this CTE + if err := e.expandSelectStmtInner(ctx, cteSelect.SelectStmt); err != nil { + return err + } + } + } + + // Expand subqueries in FROM clause + for _, fromItem := range stmt.FromClause { + if err := e.expandFromClause(ctx, fromItem); err != nil { + return err + } + } + + // Expand the target list if it has stars + if hasStarInList(stmt.TargetList) { + // Deparse the current state to get columns + tempTree := &nodes.ParseResult{ + Stmts: []*nodes.RawStmt{{Stmt: &nodes.Node{Node: &nodes.Node_SelectStmt{SelectStmt: stmt}}}}, + } + tempQuery, err := deparse(tempTree) + if err != nil { + return fmt.Errorf("failed to deparse for column lookup: %w", err) + } + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.TargetList = rewriteTargetList(stmt.TargetList, columns) + } + + return nil +} + +// expandSelectStmtInner expands nested structures without re-processing the target list +func (e *Expander) expandSelectStmtInner(ctx context.Context, stmt *nodes.SelectStmt) error { + // Expand subqueries in FROM clause + for _, fromItem := range stmt.FromClause { + if err := e.expandFromClause(ctx, fromItem); err != nil { + return err + } + } + return nil +} + +// getCTEColumnNames gets the column names for a CTE by constructing a query with proper context +func (e *Expander) getCTEColumnNames(ctx context.Context, stmt *nodes.SelectStmt, targetCTE *nodes.CommonTableExpr) ([]string, error) { + // Build a temporary query: WITH SELECT * FROM + // This gives us the proper context for resolving column names + + var ctesToInclude []*nodes.Node + for _, cte := range stmt.WithClause.Ctes { + ctesToInclude = append(ctesToInclude, cte) + cteExpr, ok := cte.Node.(*nodes.Node_CommonTableExpr) + if ok && cteExpr.CommonTableExpr.Ctename == targetCTE.Ctename { + break + } + } + + // Create a SELECT * FROM with the relevant CTEs + tempStmt := &nodes.SelectStmt{ + WithClause: &nodes.WithClause{ + Ctes: ctesToInclude, + Recursive: stmt.WithClause.Recursive, + }, + TargetList: []*nodes.Node{ + { + Node: &nodes.Node_ResTarget{ + ResTarget: &nodes.ResTarget{ + Val: &nodes.Node{ + Node: &nodes.Node_ColumnRef{ + ColumnRef: &nodes.ColumnRef{ + Fields: []*nodes.Node{ + {Node: &nodes.Node_AStar{AStar: &nodes.A_Star{}}}, + }, + }, + }, + }, + }, + }, + }, + }, + FromClause: []*nodes.Node{ + { + Node: &nodes.Node_RangeVar{ + RangeVar: &nodes.RangeVar{ + Relname: targetCTE.Ctename, + Inh: true, + }, + }, + }, + }, + } + + tempTree := &nodes.ParseResult{ + Stmts: []*nodes.RawStmt{{Stmt: &nodes.Node{Node: &nodes.Node_SelectStmt{SelectStmt: tempStmt}}}}, + } + tempQuery, err := deparse(tempTree) + if err != nil { + return nil, fmt.Errorf("failed to deparse CTE query: %w", err) + } + + return e.getColumnNames(ctx, tempQuery) +} + +// expandInsertStmt expands * in an INSERT statement's RETURNING clause +func (e *Expander) expandInsertStmt(ctx context.Context, stmt *nodes.InsertStmt) error { + // Expand CTEs first + if stmt.WithClause != nil { + for _, cte := range stmt.WithClause.Ctes { + if err := e.expandNode(ctx, cte); err != nil { + return err + } + } + } + + // Expand the SELECT part if present + if stmt.SelectStmt != nil { + if err := e.expandNode(ctx, stmt.SelectStmt); err != nil { + return err + } + } + + // Expand RETURNING clause + if hasStarInList(stmt.ReturningList) { + tempTree := &nodes.ParseResult{ + Stmts: []*nodes.RawStmt{{Stmt: &nodes.Node{Node: &nodes.Node_InsertStmt{InsertStmt: stmt}}}}, + } + tempQuery, err := deparse(tempTree) + if err != nil { + return fmt.Errorf("failed to deparse for column lookup: %w", err) + } + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) + } + + return nil +} + +// expandUpdateStmt expands * in an UPDATE statement's RETURNING clause +func (e *Expander) expandUpdateStmt(ctx context.Context, stmt *nodes.UpdateStmt) error { + // Expand CTEs first + if stmt.WithClause != nil { + for _, cte := range stmt.WithClause.Ctes { + if err := e.expandNode(ctx, cte); err != nil { + return err + } + } + } + + // Expand RETURNING clause + if hasStarInList(stmt.ReturningList) { + tempTree := &nodes.ParseResult{ + Stmts: []*nodes.RawStmt{{Stmt: &nodes.Node{Node: &nodes.Node_UpdateStmt{UpdateStmt: stmt}}}}, + } + tempQuery, err := deparse(tempTree) + if err != nil { + return fmt.Errorf("failed to deparse for column lookup: %w", err) + } + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) + } + + return nil +} + +// expandDeleteStmt expands * in a DELETE statement's RETURNING clause +func (e *Expander) expandDeleteStmt(ctx context.Context, stmt *nodes.DeleteStmt) error { + // Expand CTEs first + if stmt.WithClause != nil { + for _, cte := range stmt.WithClause.Ctes { + if err := e.expandNode(ctx, cte); err != nil { + return err + } + } + } + + // Expand RETURNING clause + if hasStarInList(stmt.ReturningList) { + tempTree := &nodes.ParseResult{ + Stmts: []*nodes.RawStmt{{Stmt: &nodes.Node{Node: &nodes.Node_DeleteStmt{DeleteStmt: stmt}}}}, + } + tempQuery, err := deparse(tempTree) + if err != nil { + return fmt.Errorf("failed to deparse for column lookup: %w", err) + } + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) + } + + return nil +} + +// expandFromClause expands * in subqueries within FROM clause +func (e *Expander) expandFromClause(ctx context.Context, node *nodes.Node) error { + if node == nil { + return nil + } + + switch n := node.Node.(type) { + case *nodes.Node_RangeSubselect: + if n.RangeSubselect.Subquery != nil { + return e.expandNode(ctx, n.RangeSubselect.Subquery) + } + case *nodes.Node_JoinExpr: + if err := e.expandFromClause(ctx, n.JoinExpr.Larg); err != nil { + return err + } + if err := e.expandFromClause(ctx, n.JoinExpr.Rarg); err != nil { + return err + } + } + return nil +} + +// hasStarAnywhere checks if there's a * anywhere in the statement +func hasStarAnywhere(node *nodes.Node) bool { + if node == nil { + return false + } + + switch n := node.Node.(type) { + case *nodes.Node_SelectStmt: + if hasStarInList(n.SelectStmt.TargetList) { + return true + } + if n.SelectStmt.WithClause != nil { + for _, cte := range n.SelectStmt.WithClause.Ctes { + if hasStarAnywhere(cte) { + return true + } + } + } + for _, from := range n.SelectStmt.FromClause { + if hasStarAnywhere(from) { + return true + } + } + case *nodes.Node_InsertStmt: + if hasStarInList(n.InsertStmt.ReturningList) { + return true + } + if n.InsertStmt.WithClause != nil { + for _, cte := range n.InsertStmt.WithClause.Ctes { + if hasStarAnywhere(cte) { + return true + } + } + } + if hasStarAnywhere(n.InsertStmt.SelectStmt) { + return true + } + case *nodes.Node_UpdateStmt: + if hasStarInList(n.UpdateStmt.ReturningList) { + return true + } + if n.UpdateStmt.WithClause != nil { + for _, cte := range n.UpdateStmt.WithClause.Ctes { + if hasStarAnywhere(cte) { + return true + } + } + } + case *nodes.Node_DeleteStmt: + if hasStarInList(n.DeleteStmt.ReturningList) { + return true + } + if n.DeleteStmt.WithClause != nil { + for _, cte := range n.DeleteStmt.WithClause.Ctes { + if hasStarAnywhere(cte) { + return true + } + } + } + case *nodes.Node_CommonTableExpr: + return hasStarAnywhere(n.CommonTableExpr.Ctequery) + case *nodes.Node_RangeSubselect: + return hasStarAnywhere(n.RangeSubselect.Subquery) + case *nodes.Node_JoinExpr: + return hasStarAnywhere(n.JoinExpr.Larg) || hasStarAnywhere(n.JoinExpr.Rarg) + } + return false +} + +// hasStarInList checks if a target list contains a * expression +func hasStarInList(targets []*nodes.Node) bool { + for _, target := range targets { + resTarget, ok := target.Node.(*nodes.Node_ResTarget) + if !ok { + continue + } + if resTarget.ResTarget.Val == nil { + continue + } + colRef, ok := resTarget.ResTarget.Val.Node.(*nodes.Node_ColumnRef) + if !ok { + continue + } + for _, field := range colRef.ColumnRef.Fields { + if _, ok := field.Node.(*nodes.Node_AStar); ok { + return true + } + } + } + return false +} + +// getColumnNames prepares the query and returns the column names from the result +func (e *Expander) getColumnNames(ctx context.Context, query string) ([]string, error) { + conn, err := e.pool.Acquire(ctx) + if err != nil { + return nil, err + } + defer conn.Release() + + // Prepare the statement to get column metadata + desc, err := conn.Conn().Prepare(ctx, "", query) + if err != nil { + return nil, err + } + + columns := make([]string, len(desc.Fields)) + for i, field := range desc.Fields { + columns[i] = field.Name + } + + return columns, nil +} + +// countStarsInList counts the number of * expressions in a target list +func countStarsInList(targets []*nodes.Node) int { + count := 0 + for _, target := range targets { + resTarget, ok := target.Node.(*nodes.Node_ResTarget) + if !ok { + continue + } + if resTarget.ResTarget.Val == nil { + continue + } + colRef, ok := resTarget.ResTarget.Val.Node.(*nodes.Node_ColumnRef) + if !ok { + continue + } + for _, field := range colRef.ColumnRef.Fields { + if _, ok := field.Node.(*nodes.Node_AStar); ok { + count++ + break + } + } + } + return count +} + +// countNonStarsInList counts the number of non-* expressions in a target list +func countNonStarsInList(targets []*nodes.Node) int { + count := 0 + for _, target := range targets { + resTarget, ok := target.Node.(*nodes.Node_ResTarget) + if !ok { + count++ + continue + } + if resTarget.ResTarget.Val == nil { + count++ + continue + } + colRef, ok := resTarget.ResTarget.Val.Node.(*nodes.Node_ColumnRef) + if !ok { + count++ + continue + } + isStar := false + for _, field := range colRef.ColumnRef.Fields { + if _, ok := field.Node.(*nodes.Node_AStar); ok { + isStar = true + break + } + } + if !isStar { + count++ + } + } + return count +} + +// rewriteTargetList replaces * in a target list with explicit column references +func rewriteTargetList(targets []*nodes.Node, columns []string) []*nodes.Node { + starCount := countStarsInList(targets) + nonStarCount := countNonStarsInList(targets) + + // Calculate how many columns each * expands to + // Total columns = (columns per star * number of stars) + non-star columns + // So: columns per star = (total - non-star) / stars + columnsPerStar := 0 + if starCount > 0 { + columnsPerStar = (len(columns) - nonStarCount) / starCount + } + + newTargets := make([]*nodes.Node, 0, len(columns)) + colIndex := 0 + + for _, target := range targets { + resTarget, ok := target.Node.(*nodes.Node_ResTarget) + if !ok { + newTargets = append(newTargets, target) + colIndex++ + continue + } + + if resTarget.ResTarget.Val == nil { + newTargets = append(newTargets, target) + colIndex++ + continue + } + + colRef, ok := resTarget.ResTarget.Val.Node.(*nodes.Node_ColumnRef) + if !ok { + newTargets = append(newTargets, target) + colIndex++ + continue + } + + // Check if this is a * (with or without table qualifier) + // and extract any table prefix + isStar := false + var tablePrefix []string + for _, field := range colRef.ColumnRef.Fields { + if _, ok := field.Node.(*nodes.Node_AStar); ok { + isStar = true + break + } + // Collect prefix parts (schema, table name) + if str, ok := field.Node.(*nodes.Node_String_); ok { + tablePrefix = append(tablePrefix, str.String_.Sval) + } + } + + if !isStar { + newTargets = append(newTargets, target) + colIndex++ + continue + } + + // Replace * with explicit column references + for i := 0; i < columnsPerStar && colIndex < len(columns); i++ { + newTargets = append(newTargets, makeColumnTargetWithPrefix(columns[colIndex], tablePrefix)) + colIndex++ + } + } + + return newTargets +} + +// makeColumnTargetWithPrefix creates a ResTarget node for a column reference with optional table prefix +func makeColumnTargetWithPrefix(colName string, prefix []string) *nodes.Node { + fields := make([]*nodes.Node, 0, len(prefix)+1) + + // Add prefix parts (schema, table name) + for _, p := range prefix { + fields = append(fields, &nodes.Node{ + Node: &nodes.Node_String_{ + String_: &nodes.String{ + Sval: p, + }, + }, + }) + } + + // Add column name + fields = append(fields, &nodes.Node{ + Node: &nodes.Node_String_{ + String_: &nodes.String{ + Sval: colName, + }, + }, + }) + + return &nodes.Node{ + Node: &nodes.Node_ResTarget{ + ResTarget: &nodes.ResTarget{ + Val: &nodes.Node{ + Node: &nodes.Node_ColumnRef{ + ColumnRef: &nodes.ColumnRef{ + Fields: fields, + }, + }, + }, + }, + }, + } +} diff --git a/internal/engine/postgresql/expander/expander_test.go b/internal/engine/postgresql/expander/expander_test.go new file mode 100644 index 0000000000..1d2024e6a8 --- /dev/null +++ b/internal/engine/postgresql/expander/expander_test.go @@ -0,0 +1,120 @@ +package expander + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5/pgxpool" +) + +func TestExpand(t *testing.T) { + // Skip if no database connection available + uri := os.Getenv("POSTGRESQL_SERVER_URI") + if uri == "" { + uri = "postgres://postgres:mysecretpassword@localhost:5432/postgres" + } + + ctx := context.Background() + + pool, err := pgxpool.New(ctx, uri) + if err != nil { + t.Skipf("could not connect to database: %v", err) + } + defer pool.Close() + + // Create a test table + _, err = pool.Exec(ctx, ` + DROP TABLE IF EXISTS authors; + CREATE TABLE authors ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + bio TEXT + ); + `) + if err != nil { + t.Fatalf("failed to create test table: %v", err) + } + defer pool.Exec(ctx, "DROP TABLE IF EXISTS authors") + + exp := New(pool) + + tests := []struct { + name string + query string + expected string + }{ + { + name: "simple select star", + query: "SELECT * FROM authors", + expected: "SELECT id, name, bio FROM authors", + }, + { + name: "select with no star", + query: "SELECT id, name FROM authors", + expected: "SELECT id, name FROM authors", + }, + { + name: "select star with where clause", + query: "SELECT * FROM authors WHERE id = 1", + expected: "SELECT id, name, bio FROM authors WHERE id = 1", + }, + { + name: "double star", + query: "SELECT *, * FROM authors", + expected: "SELECT id, name, bio, id, name, bio FROM authors", + }, + { + name: "table qualified star", + query: "SELECT authors.* FROM authors", + expected: "SELECT authors.id, authors.name, authors.bio FROM authors", + }, + { + name: "star in middle of columns", + query: "SELECT id, *, name FROM authors", + expected: "SELECT id, id, name, bio, name FROM authors", + }, + { + name: "insert returning star", + query: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING *", + expected: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, name, bio", + }, + { + name: "insert returning mixed", + query: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, *", + expected: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, id, name, bio", + }, + { + name: "update returning star", + query: "UPDATE authors SET name = 'Jane' WHERE id = 1 RETURNING *", + expected: "UPDATE authors SET name = 'Jane' WHERE id = 1 RETURNING id, name, bio", + }, + { + name: "delete returning star", + query: "DELETE FROM authors WHERE id = 1 RETURNING *", + expected: "DELETE FROM authors WHERE id = 1 RETURNING id, name, bio", + }, + { + name: "cte with select star", + query: "WITH a AS (SELECT * FROM authors) SELECT * FROM a", + expected: "WITH a AS (SELECT id, name, bio FROM authors) SELECT id, name, bio FROM a", + }, + { + name: "multiple ctes with dependency", + query: "WITH a AS (SELECT * FROM authors), b AS (SELECT * FROM a) SELECT * FROM b", + expected: "WITH a AS (SELECT id, name, bio FROM authors), b AS (SELECT id, name, bio FROM a) SELECT id, name, bio FROM b", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := exp.Expand(ctx, tc.query) + if err != nil { + t.Fatalf("Expand failed: %v", err) + } + if result != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, result) + } + }) + } +} diff --git a/internal/engine/postgresql/expander/parse_default.go b/internal/engine/postgresql/expander/parse_default.go new file mode 100644 index 0000000000..64d61657df --- /dev/null +++ b/internal/engine/postgresql/expander/parse_default.go @@ -0,0 +1,10 @@ +//go:build !windows && cgo + +package expander + +import ( + nodes "github.com/pganalyze/pg_query_go/v6" +) + +var parse = nodes.Parse +var deparse = nodes.Deparse diff --git a/internal/engine/postgresql/expander/parse_wasi.go b/internal/engine/postgresql/expander/parse_wasi.go new file mode 100644 index 0000000000..f1ef48e1c9 --- /dev/null +++ b/internal/engine/postgresql/expander/parse_wasi.go @@ -0,0 +1,10 @@ +//go:build windows || !cgo + +package expander + +import ( + nodes "github.com/wasilibs/go-pgquery" +) + +var parse = nodes.Parse +var deparse = nodes.Deparse From 475cfcf744d9fca5e8fbb1edcf8300b1542b5f83 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 18:53:36 -0800 Subject: [PATCH 2/8] feat(expander): port expander to use internal AST types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move expander from internal/engine/postgresql/expander to internal/x/expander and port it to use the internal AST types instead of pg_query nodes. Key changes: - Use internal AST types (*ast.SelectStmt, *ast.InsertStmt, etc.) - Use astutils.Search for star detection - Use ast.Format instead of pg_query deparse - Add Parser interface for dependency injection - Add test cases for COUNT(*) (should not be expanded) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .../engine/postgresql/expander/expander.go | 599 ------------------ .../postgresql/expander/parse_default.go | 10 - .../engine/postgresql/expander/parse_wasi.go | 10 - internal/x/expander/expander.go | 521 +++++++++++++++ .../expander/expander_test.go | 42 +- 5 files changed, 550 insertions(+), 632 deletions(-) delete mode 100644 internal/engine/postgresql/expander/expander.go delete mode 100644 internal/engine/postgresql/expander/parse_default.go delete mode 100644 internal/engine/postgresql/expander/parse_wasi.go create mode 100644 internal/x/expander/expander.go rename internal/{engine/postgresql => x}/expander/expander_test.go (63%) diff --git a/internal/engine/postgresql/expander/expander.go b/internal/engine/postgresql/expander/expander.go deleted file mode 100644 index 04129e5275..0000000000 --- a/internal/engine/postgresql/expander/expander.go +++ /dev/null @@ -1,599 +0,0 @@ -package expander - -import ( - "context" - "fmt" - - "github.com/jackc/pgx/v5/pgxpool" - nodes "github.com/pganalyze/pg_query_go/v6" -) - -// Expander expands SELECT * and RETURNING * queries by replacing * with explicit column names -// obtained from preparing the query against a PostgreSQL database. -type Expander struct { - pool *pgxpool.Pool -} - -// New creates a new Expander with the given connection pool. -func New(pool *pgxpool.Pool) *Expander { - return &Expander{pool: pool} -} - -// Expand takes a SQL query, and if it contains * in SELECT or RETURNING clause, -// expands it to use explicit column names. Returns the expanded query string. -func (e *Expander) Expand(ctx context.Context, query string) (string, error) { - // Parse the query - tree, err := parse(query) - if err != nil { - return "", fmt.Errorf("failed to parse query: %w", err) - } - - if len(tree.Stmts) == 0 { - return query, nil - } - - stmt := tree.Stmts[0].Stmt - - // Check if there's any star in the statement (including CTEs, subqueries, etc.) - if !hasStarAnywhere(stmt) { - return query, nil - } - - // Expand all stars in the statement recursively - if err := e.expandNode(ctx, stmt); err != nil { - return "", err - } - - // Deparse the modified AST back to SQL - expanded, err := deparse(tree) - if err != nil { - return "", fmt.Errorf("failed to deparse query: %w", err) - } - - return expanded, nil -} - -// expandNode recursively expands * in all parts of the statement -func (e *Expander) expandNode(ctx context.Context, node *nodes.Node) error { - if node == nil { - return nil - } - - switch n := node.Node.(type) { - case *nodes.Node_SelectStmt: - return e.expandSelectStmt(ctx, n.SelectStmt) - case *nodes.Node_InsertStmt: - return e.expandInsertStmt(ctx, n.InsertStmt) - case *nodes.Node_UpdateStmt: - return e.expandUpdateStmt(ctx, n.UpdateStmt) - case *nodes.Node_DeleteStmt: - return e.expandDeleteStmt(ctx, n.DeleteStmt) - case *nodes.Node_CommonTableExpr: - return e.expandNode(ctx, n.CommonTableExpr.Ctequery) - } - return nil -} - -// expandSelectStmt expands * in a SELECT statement including CTEs and subqueries -func (e *Expander) expandSelectStmt(ctx context.Context, stmt *nodes.SelectStmt) error { - // First expand any CTEs - must be done in order since later CTEs may depend on earlier ones - if stmt.WithClause != nil { - for _, cte := range stmt.WithClause.Ctes { - cteExpr, ok := cte.Node.(*nodes.Node_CommonTableExpr) - if !ok { - continue - } - cteSelect, ok := cteExpr.CommonTableExpr.Ctequery.Node.(*nodes.Node_SelectStmt) - if !ok { - continue - } - if hasStarInList(cteSelect.SelectStmt.TargetList) { - // Deparse the full statement (with WITH clause context) but query just this CTE - // We need to build a query that includes all prior CTEs for context - columns, err := e.getCTEColumnNames(ctx, stmt, cteExpr.CommonTableExpr) - if err != nil { - return err - } - cteSelect.SelectStmt.TargetList = rewriteTargetList(cteSelect.SelectStmt.TargetList, columns) - } - // Recursively handle nested CTEs/subqueries in this CTE - if err := e.expandSelectStmtInner(ctx, cteSelect.SelectStmt); err != nil { - return err - } - } - } - - // Expand subqueries in FROM clause - for _, fromItem := range stmt.FromClause { - if err := e.expandFromClause(ctx, fromItem); err != nil { - return err - } - } - - // Expand the target list if it has stars - if hasStarInList(stmt.TargetList) { - // Deparse the current state to get columns - tempTree := &nodes.ParseResult{ - Stmts: []*nodes.RawStmt{{Stmt: &nodes.Node{Node: &nodes.Node_SelectStmt{SelectStmt: stmt}}}}, - } - tempQuery, err := deparse(tempTree) - if err != nil { - return fmt.Errorf("failed to deparse for column lookup: %w", err) - } - columns, err := e.getColumnNames(ctx, tempQuery) - if err != nil { - return fmt.Errorf("failed to get column names: %w", err) - } - stmt.TargetList = rewriteTargetList(stmt.TargetList, columns) - } - - return nil -} - -// expandSelectStmtInner expands nested structures without re-processing the target list -func (e *Expander) expandSelectStmtInner(ctx context.Context, stmt *nodes.SelectStmt) error { - // Expand subqueries in FROM clause - for _, fromItem := range stmt.FromClause { - if err := e.expandFromClause(ctx, fromItem); err != nil { - return err - } - } - return nil -} - -// getCTEColumnNames gets the column names for a CTE by constructing a query with proper context -func (e *Expander) getCTEColumnNames(ctx context.Context, stmt *nodes.SelectStmt, targetCTE *nodes.CommonTableExpr) ([]string, error) { - // Build a temporary query: WITH SELECT * FROM - // This gives us the proper context for resolving column names - - var ctesToInclude []*nodes.Node - for _, cte := range stmt.WithClause.Ctes { - ctesToInclude = append(ctesToInclude, cte) - cteExpr, ok := cte.Node.(*nodes.Node_CommonTableExpr) - if ok && cteExpr.CommonTableExpr.Ctename == targetCTE.Ctename { - break - } - } - - // Create a SELECT * FROM with the relevant CTEs - tempStmt := &nodes.SelectStmt{ - WithClause: &nodes.WithClause{ - Ctes: ctesToInclude, - Recursive: stmt.WithClause.Recursive, - }, - TargetList: []*nodes.Node{ - { - Node: &nodes.Node_ResTarget{ - ResTarget: &nodes.ResTarget{ - Val: &nodes.Node{ - Node: &nodes.Node_ColumnRef{ - ColumnRef: &nodes.ColumnRef{ - Fields: []*nodes.Node{ - {Node: &nodes.Node_AStar{AStar: &nodes.A_Star{}}}, - }, - }, - }, - }, - }, - }, - }, - }, - FromClause: []*nodes.Node{ - { - Node: &nodes.Node_RangeVar{ - RangeVar: &nodes.RangeVar{ - Relname: targetCTE.Ctename, - Inh: true, - }, - }, - }, - }, - } - - tempTree := &nodes.ParseResult{ - Stmts: []*nodes.RawStmt{{Stmt: &nodes.Node{Node: &nodes.Node_SelectStmt{SelectStmt: tempStmt}}}}, - } - tempQuery, err := deparse(tempTree) - if err != nil { - return nil, fmt.Errorf("failed to deparse CTE query: %w", err) - } - - return e.getColumnNames(ctx, tempQuery) -} - -// expandInsertStmt expands * in an INSERT statement's RETURNING clause -func (e *Expander) expandInsertStmt(ctx context.Context, stmt *nodes.InsertStmt) error { - // Expand CTEs first - if stmt.WithClause != nil { - for _, cte := range stmt.WithClause.Ctes { - if err := e.expandNode(ctx, cte); err != nil { - return err - } - } - } - - // Expand the SELECT part if present - if stmt.SelectStmt != nil { - if err := e.expandNode(ctx, stmt.SelectStmt); err != nil { - return err - } - } - - // Expand RETURNING clause - if hasStarInList(stmt.ReturningList) { - tempTree := &nodes.ParseResult{ - Stmts: []*nodes.RawStmt{{Stmt: &nodes.Node{Node: &nodes.Node_InsertStmt{InsertStmt: stmt}}}}, - } - tempQuery, err := deparse(tempTree) - if err != nil { - return fmt.Errorf("failed to deparse for column lookup: %w", err) - } - columns, err := e.getColumnNames(ctx, tempQuery) - if err != nil { - return fmt.Errorf("failed to get column names: %w", err) - } - stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) - } - - return nil -} - -// expandUpdateStmt expands * in an UPDATE statement's RETURNING clause -func (e *Expander) expandUpdateStmt(ctx context.Context, stmt *nodes.UpdateStmt) error { - // Expand CTEs first - if stmt.WithClause != nil { - for _, cte := range stmt.WithClause.Ctes { - if err := e.expandNode(ctx, cte); err != nil { - return err - } - } - } - - // Expand RETURNING clause - if hasStarInList(stmt.ReturningList) { - tempTree := &nodes.ParseResult{ - Stmts: []*nodes.RawStmt{{Stmt: &nodes.Node{Node: &nodes.Node_UpdateStmt{UpdateStmt: stmt}}}}, - } - tempQuery, err := deparse(tempTree) - if err != nil { - return fmt.Errorf("failed to deparse for column lookup: %w", err) - } - columns, err := e.getColumnNames(ctx, tempQuery) - if err != nil { - return fmt.Errorf("failed to get column names: %w", err) - } - stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) - } - - return nil -} - -// expandDeleteStmt expands * in a DELETE statement's RETURNING clause -func (e *Expander) expandDeleteStmt(ctx context.Context, stmt *nodes.DeleteStmt) error { - // Expand CTEs first - if stmt.WithClause != nil { - for _, cte := range stmt.WithClause.Ctes { - if err := e.expandNode(ctx, cte); err != nil { - return err - } - } - } - - // Expand RETURNING clause - if hasStarInList(stmt.ReturningList) { - tempTree := &nodes.ParseResult{ - Stmts: []*nodes.RawStmt{{Stmt: &nodes.Node{Node: &nodes.Node_DeleteStmt{DeleteStmt: stmt}}}}, - } - tempQuery, err := deparse(tempTree) - if err != nil { - return fmt.Errorf("failed to deparse for column lookup: %w", err) - } - columns, err := e.getColumnNames(ctx, tempQuery) - if err != nil { - return fmt.Errorf("failed to get column names: %w", err) - } - stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) - } - - return nil -} - -// expandFromClause expands * in subqueries within FROM clause -func (e *Expander) expandFromClause(ctx context.Context, node *nodes.Node) error { - if node == nil { - return nil - } - - switch n := node.Node.(type) { - case *nodes.Node_RangeSubselect: - if n.RangeSubselect.Subquery != nil { - return e.expandNode(ctx, n.RangeSubselect.Subquery) - } - case *nodes.Node_JoinExpr: - if err := e.expandFromClause(ctx, n.JoinExpr.Larg); err != nil { - return err - } - if err := e.expandFromClause(ctx, n.JoinExpr.Rarg); err != nil { - return err - } - } - return nil -} - -// hasStarAnywhere checks if there's a * anywhere in the statement -func hasStarAnywhere(node *nodes.Node) bool { - if node == nil { - return false - } - - switch n := node.Node.(type) { - case *nodes.Node_SelectStmt: - if hasStarInList(n.SelectStmt.TargetList) { - return true - } - if n.SelectStmt.WithClause != nil { - for _, cte := range n.SelectStmt.WithClause.Ctes { - if hasStarAnywhere(cte) { - return true - } - } - } - for _, from := range n.SelectStmt.FromClause { - if hasStarAnywhere(from) { - return true - } - } - case *nodes.Node_InsertStmt: - if hasStarInList(n.InsertStmt.ReturningList) { - return true - } - if n.InsertStmt.WithClause != nil { - for _, cte := range n.InsertStmt.WithClause.Ctes { - if hasStarAnywhere(cte) { - return true - } - } - } - if hasStarAnywhere(n.InsertStmt.SelectStmt) { - return true - } - case *nodes.Node_UpdateStmt: - if hasStarInList(n.UpdateStmt.ReturningList) { - return true - } - if n.UpdateStmt.WithClause != nil { - for _, cte := range n.UpdateStmt.WithClause.Ctes { - if hasStarAnywhere(cte) { - return true - } - } - } - case *nodes.Node_DeleteStmt: - if hasStarInList(n.DeleteStmt.ReturningList) { - return true - } - if n.DeleteStmt.WithClause != nil { - for _, cte := range n.DeleteStmt.WithClause.Ctes { - if hasStarAnywhere(cte) { - return true - } - } - } - case *nodes.Node_CommonTableExpr: - return hasStarAnywhere(n.CommonTableExpr.Ctequery) - case *nodes.Node_RangeSubselect: - return hasStarAnywhere(n.RangeSubselect.Subquery) - case *nodes.Node_JoinExpr: - return hasStarAnywhere(n.JoinExpr.Larg) || hasStarAnywhere(n.JoinExpr.Rarg) - } - return false -} - -// hasStarInList checks if a target list contains a * expression -func hasStarInList(targets []*nodes.Node) bool { - for _, target := range targets { - resTarget, ok := target.Node.(*nodes.Node_ResTarget) - if !ok { - continue - } - if resTarget.ResTarget.Val == nil { - continue - } - colRef, ok := resTarget.ResTarget.Val.Node.(*nodes.Node_ColumnRef) - if !ok { - continue - } - for _, field := range colRef.ColumnRef.Fields { - if _, ok := field.Node.(*nodes.Node_AStar); ok { - return true - } - } - } - return false -} - -// getColumnNames prepares the query and returns the column names from the result -func (e *Expander) getColumnNames(ctx context.Context, query string) ([]string, error) { - conn, err := e.pool.Acquire(ctx) - if err != nil { - return nil, err - } - defer conn.Release() - - // Prepare the statement to get column metadata - desc, err := conn.Conn().Prepare(ctx, "", query) - if err != nil { - return nil, err - } - - columns := make([]string, len(desc.Fields)) - for i, field := range desc.Fields { - columns[i] = field.Name - } - - return columns, nil -} - -// countStarsInList counts the number of * expressions in a target list -func countStarsInList(targets []*nodes.Node) int { - count := 0 - for _, target := range targets { - resTarget, ok := target.Node.(*nodes.Node_ResTarget) - if !ok { - continue - } - if resTarget.ResTarget.Val == nil { - continue - } - colRef, ok := resTarget.ResTarget.Val.Node.(*nodes.Node_ColumnRef) - if !ok { - continue - } - for _, field := range colRef.ColumnRef.Fields { - if _, ok := field.Node.(*nodes.Node_AStar); ok { - count++ - break - } - } - } - return count -} - -// countNonStarsInList counts the number of non-* expressions in a target list -func countNonStarsInList(targets []*nodes.Node) int { - count := 0 - for _, target := range targets { - resTarget, ok := target.Node.(*nodes.Node_ResTarget) - if !ok { - count++ - continue - } - if resTarget.ResTarget.Val == nil { - count++ - continue - } - colRef, ok := resTarget.ResTarget.Val.Node.(*nodes.Node_ColumnRef) - if !ok { - count++ - continue - } - isStar := false - for _, field := range colRef.ColumnRef.Fields { - if _, ok := field.Node.(*nodes.Node_AStar); ok { - isStar = true - break - } - } - if !isStar { - count++ - } - } - return count -} - -// rewriteTargetList replaces * in a target list with explicit column references -func rewriteTargetList(targets []*nodes.Node, columns []string) []*nodes.Node { - starCount := countStarsInList(targets) - nonStarCount := countNonStarsInList(targets) - - // Calculate how many columns each * expands to - // Total columns = (columns per star * number of stars) + non-star columns - // So: columns per star = (total - non-star) / stars - columnsPerStar := 0 - if starCount > 0 { - columnsPerStar = (len(columns) - nonStarCount) / starCount - } - - newTargets := make([]*nodes.Node, 0, len(columns)) - colIndex := 0 - - for _, target := range targets { - resTarget, ok := target.Node.(*nodes.Node_ResTarget) - if !ok { - newTargets = append(newTargets, target) - colIndex++ - continue - } - - if resTarget.ResTarget.Val == nil { - newTargets = append(newTargets, target) - colIndex++ - continue - } - - colRef, ok := resTarget.ResTarget.Val.Node.(*nodes.Node_ColumnRef) - if !ok { - newTargets = append(newTargets, target) - colIndex++ - continue - } - - // Check if this is a * (with or without table qualifier) - // and extract any table prefix - isStar := false - var tablePrefix []string - for _, field := range colRef.ColumnRef.Fields { - if _, ok := field.Node.(*nodes.Node_AStar); ok { - isStar = true - break - } - // Collect prefix parts (schema, table name) - if str, ok := field.Node.(*nodes.Node_String_); ok { - tablePrefix = append(tablePrefix, str.String_.Sval) - } - } - - if !isStar { - newTargets = append(newTargets, target) - colIndex++ - continue - } - - // Replace * with explicit column references - for i := 0; i < columnsPerStar && colIndex < len(columns); i++ { - newTargets = append(newTargets, makeColumnTargetWithPrefix(columns[colIndex], tablePrefix)) - colIndex++ - } - } - - return newTargets -} - -// makeColumnTargetWithPrefix creates a ResTarget node for a column reference with optional table prefix -func makeColumnTargetWithPrefix(colName string, prefix []string) *nodes.Node { - fields := make([]*nodes.Node, 0, len(prefix)+1) - - // Add prefix parts (schema, table name) - for _, p := range prefix { - fields = append(fields, &nodes.Node{ - Node: &nodes.Node_String_{ - String_: &nodes.String{ - Sval: p, - }, - }, - }) - } - - // Add column name - fields = append(fields, &nodes.Node{ - Node: &nodes.Node_String_{ - String_: &nodes.String{ - Sval: colName, - }, - }, - }) - - return &nodes.Node{ - Node: &nodes.Node_ResTarget{ - ResTarget: &nodes.ResTarget{ - Val: &nodes.Node{ - Node: &nodes.Node_ColumnRef{ - ColumnRef: &nodes.ColumnRef{ - Fields: fields, - }, - }, - }, - }, - }, - } -} diff --git a/internal/engine/postgresql/expander/parse_default.go b/internal/engine/postgresql/expander/parse_default.go deleted file mode 100644 index 64d61657df..0000000000 --- a/internal/engine/postgresql/expander/parse_default.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build !windows && cgo - -package expander - -import ( - nodes "github.com/pganalyze/pg_query_go/v6" -) - -var parse = nodes.Parse -var deparse = nodes.Deparse diff --git a/internal/engine/postgresql/expander/parse_wasi.go b/internal/engine/postgresql/expander/parse_wasi.go deleted file mode 100644 index f1ef48e1c9..0000000000 --- a/internal/engine/postgresql/expander/parse_wasi.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build windows || !cgo - -package expander - -import ( - nodes "github.com/wasilibs/go-pgquery" -) - -var parse = nodes.Parse -var deparse = nodes.Deparse diff --git a/internal/x/expander/expander.go b/internal/x/expander/expander.go new file mode 100644 index 0000000000..baa8e1fd25 --- /dev/null +++ b/internal/x/expander/expander.go @@ -0,0 +1,521 @@ +package expander + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/astutils" + "github.com/sqlc-dev/sqlc/internal/sql/format" +) + +// Parser is an interface for SQL parsers that can parse SQL into AST statements. +type Parser interface { + Parse(r io.Reader) ([]ast.Statement, error) +} + +// Expander expands SELECT * and RETURNING * queries by replacing * with explicit column names +// obtained from preparing the query against a PostgreSQL database. +type Expander struct { + pool *pgxpool.Pool + parser Parser + dialect format.Dialect +} + +// New creates a new Expander with the given connection pool, parser, and dialect. +func New(pool *pgxpool.Pool, parser Parser, dialect format.Dialect) *Expander { + return &Expander{ + pool: pool, + parser: parser, + dialect: dialect, + } +} + +// Expand takes a SQL query, and if it contains * in SELECT or RETURNING clause, +// expands it to use explicit column names. Returns the expanded query string. +func (e *Expander) Expand(ctx context.Context, query string) (string, error) { + // Parse the query + stmts, err := e.parser.Parse(strings.NewReader(query)) + if err != nil { + return "", fmt.Errorf("failed to parse query: %w", err) + } + + if len(stmts) == 0 { + return query, nil + } + + stmt := stmts[0].Raw.Stmt + + // Check if there's any star in the statement (including CTEs, subqueries, etc.) + if !hasStarAnywhere(stmt) { + return query, nil + } + + // Expand all stars in the statement recursively + if err := e.expandNode(ctx, stmt); err != nil { + return "", err + } + + // Format the modified AST back to SQL + expanded := ast.Format(stmts[0].Raw, e.dialect) + + return expanded, nil +} + +// expandNode recursively expands * in all parts of the statement +func (e *Expander) expandNode(ctx context.Context, node ast.Node) error { + if node == nil { + return nil + } + + switch n := node.(type) { + case *ast.SelectStmt: + return e.expandSelectStmt(ctx, n) + case *ast.InsertStmt: + return e.expandInsertStmt(ctx, n) + case *ast.UpdateStmt: + return e.expandUpdateStmt(ctx, n) + case *ast.DeleteStmt: + return e.expandDeleteStmt(ctx, n) + case *ast.CommonTableExpr: + return e.expandNode(ctx, n.Ctequery) + } + return nil +} + +// expandSelectStmt expands * in a SELECT statement including CTEs and subqueries +func (e *Expander) expandSelectStmt(ctx context.Context, stmt *ast.SelectStmt) error { + // First expand any CTEs - must be done in order since later CTEs may depend on earlier ones + if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { + for _, cteNode := range stmt.WithClause.Ctes.Items { + cte, ok := cteNode.(*ast.CommonTableExpr) + if !ok { + continue + } + cteSelect, ok := cte.Ctequery.(*ast.SelectStmt) + if !ok { + continue + } + if hasStarInList(cteSelect.TargetList) { + // Get column names for this CTE + columns, err := e.getCTEColumnNames(ctx, stmt, cte) + if err != nil { + return err + } + cteSelect.TargetList = rewriteTargetList(cteSelect.TargetList, columns) + } + // Recursively handle nested CTEs/subqueries in this CTE + if err := e.expandSelectStmtInner(ctx, cteSelect); err != nil { + return err + } + } + } + + // Expand subqueries in FROM clause + if stmt.FromClause != nil { + for _, fromItem := range stmt.FromClause.Items { + if err := e.expandFromClause(ctx, fromItem); err != nil { + return err + } + } + } + + // Expand the target list if it has stars + if hasStarInList(stmt.TargetList) { + // Format the current state to get columns + tempRaw := &ast.RawStmt{Stmt: stmt} + tempQuery := ast.Format(tempRaw, e.dialect) + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.TargetList = rewriteTargetList(stmt.TargetList, columns) + } + + return nil +} + +// expandSelectStmtInner expands nested structures without re-processing the target list +func (e *Expander) expandSelectStmtInner(ctx context.Context, stmt *ast.SelectStmt) error { + // Expand subqueries in FROM clause + if stmt.FromClause != nil { + for _, fromItem := range stmt.FromClause.Items { + if err := e.expandFromClause(ctx, fromItem); err != nil { + return err + } + } + } + return nil +} + +// getCTEColumnNames gets the column names for a CTE by constructing a query with proper context +func (e *Expander) getCTEColumnNames(ctx context.Context, stmt *ast.SelectStmt, targetCTE *ast.CommonTableExpr) ([]string, error) { + // Build a temporary query: WITH SELECT * FROM + var ctesToInclude []ast.Node + for _, cteNode := range stmt.WithClause.Ctes.Items { + ctesToInclude = append(ctesToInclude, cteNode) + cte, ok := cteNode.(*ast.CommonTableExpr) + if ok && cte.Ctename != nil && targetCTE.Ctename != nil && *cte.Ctename == *targetCTE.Ctename { + break + } + } + + // Create a SELECT * FROM with the relevant CTEs + cteName := "" + if targetCTE.Ctename != nil { + cteName = *targetCTE.Ctename + } + + tempStmt := &ast.SelectStmt{ + WithClause: &ast.WithClause{ + Ctes: &ast.List{Items: ctesToInclude}, + Recursive: stmt.WithClause.Recursive, + }, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{&ast.A_Star{}}, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: &cteName, + }, + }, + }, + } + + tempRaw := &ast.RawStmt{Stmt: tempStmt} + tempQuery := ast.Format(tempRaw, e.dialect) + + return e.getColumnNames(ctx, tempQuery) +} + +// expandInsertStmt expands * in an INSERT statement's RETURNING clause +func (e *Expander) expandInsertStmt(ctx context.Context, stmt *ast.InsertStmt) error { + // Expand CTEs first + if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { + for _, cte := range stmt.WithClause.Ctes.Items { + if err := e.expandNode(ctx, cte); err != nil { + return err + } + } + } + + // Expand the SELECT part if present + if stmt.SelectStmt != nil { + if err := e.expandNode(ctx, stmt.SelectStmt); err != nil { + return err + } + } + + // Expand RETURNING clause + if hasStarInList(stmt.ReturningList) { + tempRaw := &ast.RawStmt{Stmt: stmt} + tempQuery := ast.Format(tempRaw, e.dialect) + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) + } + + return nil +} + +// expandUpdateStmt expands * in an UPDATE statement's RETURNING clause +func (e *Expander) expandUpdateStmt(ctx context.Context, stmt *ast.UpdateStmt) error { + // Expand CTEs first + if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { + for _, cte := range stmt.WithClause.Ctes.Items { + if err := e.expandNode(ctx, cte); err != nil { + return err + } + } + } + + // Expand RETURNING clause + if hasStarInList(stmt.ReturningList) { + tempRaw := &ast.RawStmt{Stmt: stmt} + tempQuery := ast.Format(tempRaw, e.dialect) + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) + } + + return nil +} + +// expandDeleteStmt expands * in a DELETE statement's RETURNING clause +func (e *Expander) expandDeleteStmt(ctx context.Context, stmt *ast.DeleteStmt) error { + // Expand CTEs first + if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { + for _, cte := range stmt.WithClause.Ctes.Items { + if err := e.expandNode(ctx, cte); err != nil { + return err + } + } + } + + // Expand RETURNING clause + if hasStarInList(stmt.ReturningList) { + tempRaw := &ast.RawStmt{Stmt: stmt} + tempQuery := ast.Format(tempRaw, e.dialect) + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) + } + + return nil +} + +// expandFromClause expands * in subqueries within FROM clause +func (e *Expander) expandFromClause(ctx context.Context, node ast.Node) error { + if node == nil { + return nil + } + + switch n := node.(type) { + case *ast.RangeSubselect: + if n.Subquery != nil { + return e.expandNode(ctx, n.Subquery) + } + case *ast.JoinExpr: + if err := e.expandFromClause(ctx, n.Larg); err != nil { + return err + } + if err := e.expandFromClause(ctx, n.Rarg); err != nil { + return err + } + } + return nil +} + +// hasStarAnywhere checks if there's a * anywhere in the statement using astutils.Search +func hasStarAnywhere(node ast.Node) bool { + if node == nil { + return false + } + // Use astutils.Search to find any A_Star node in the AST + stars := astutils.Search(node, func(n ast.Node) bool { + _, ok := n.(*ast.A_Star) + return ok + }) + return len(stars.Items) > 0 +} + +// hasStarInList checks if a target list contains a * expression using astutils.Search +func hasStarInList(targets *ast.List) bool { + if targets == nil { + return false + } + // Use astutils.Search to find any A_Star node in the target list + stars := astutils.Search(targets, func(n ast.Node) bool { + _, ok := n.(*ast.A_Star) + return ok + }) + return len(stars.Items) > 0 +} + +// getColumnNames prepares the query and returns the column names from the result +func (e *Expander) getColumnNames(ctx context.Context, query string) ([]string, error) { + conn, err := e.pool.Acquire(ctx) + if err != nil { + return nil, err + } + defer conn.Release() + + // Prepare the statement to get column metadata + desc, err := conn.Conn().Prepare(ctx, "", query) + if err != nil { + return nil, err + } + + columns := make([]string, len(desc.Fields)) + for i, field := range desc.Fields { + columns[i] = field.Name + } + + return columns, nil +} + +// countStarsInList counts the number of * expressions in a target list +func countStarsInList(targets *ast.List) int { + if targets == nil { + return 0 + } + count := 0 + for _, target := range targets.Items { + resTarget, ok := target.(*ast.ResTarget) + if !ok { + continue + } + if resTarget.Val == nil { + continue + } + colRef, ok := resTarget.Val.(*ast.ColumnRef) + if !ok { + continue + } + if colRef.Fields == nil { + continue + } + for _, field := range colRef.Fields.Items { + if _, ok := field.(*ast.A_Star); ok { + count++ + break + } + } + } + return count +} + +// countNonStarsInList counts the number of non-* expressions in a target list +func countNonStarsInList(targets *ast.List) int { + if targets == nil { + return 0 + } + count := 0 + for _, target := range targets.Items { + resTarget, ok := target.(*ast.ResTarget) + if !ok { + count++ + continue + } + if resTarget.Val == nil { + count++ + continue + } + colRef, ok := resTarget.Val.(*ast.ColumnRef) + if !ok { + count++ + continue + } + if colRef.Fields == nil { + count++ + continue + } + isStar := false + for _, field := range colRef.Fields.Items { + if _, ok := field.(*ast.A_Star); ok { + isStar = true + break + } + } + if !isStar { + count++ + } + } + return count +} + +// rewriteTargetList replaces * in a target list with explicit column references +func rewriteTargetList(targets *ast.List, columns []string) *ast.List { + if targets == nil { + return nil + } + + starCount := countStarsInList(targets) + nonStarCount := countNonStarsInList(targets) + + // Calculate how many columns each * expands to + // Total columns = (columns per star * number of stars) + non-star columns + // So: columns per star = (total - non-star) / stars + columnsPerStar := 0 + if starCount > 0 { + columnsPerStar = (len(columns) - nonStarCount) / starCount + } + + newItems := make([]ast.Node, 0, len(columns)) + colIndex := 0 + + for _, target := range targets.Items { + resTarget, ok := target.(*ast.ResTarget) + if !ok { + newItems = append(newItems, target) + colIndex++ + continue + } + + if resTarget.Val == nil { + newItems = append(newItems, target) + colIndex++ + continue + } + + colRef, ok := resTarget.Val.(*ast.ColumnRef) + if !ok { + newItems = append(newItems, target) + colIndex++ + continue + } + + if colRef.Fields == nil { + newItems = append(newItems, target) + colIndex++ + continue + } + + // Check if this is a * (with or without table qualifier) + // and extract any table prefix + isStar := false + var tablePrefix []string + for _, field := range colRef.Fields.Items { + if _, ok := field.(*ast.A_Star); ok { + isStar = true + break + } + // Collect prefix parts (schema, table name) + if str, ok := field.(*ast.String); ok { + tablePrefix = append(tablePrefix, str.Str) + } + } + + if !isStar { + newItems = append(newItems, target) + colIndex++ + continue + } + + // Replace * with explicit column references + for i := 0; i < columnsPerStar && colIndex < len(columns); i++ { + newItems = append(newItems, makeColumnTargetWithPrefix(columns[colIndex], tablePrefix)) + colIndex++ + } + } + + return &ast.List{Items: newItems} +} + +// makeColumnTargetWithPrefix creates a ResTarget node for a column reference with optional table prefix +func makeColumnTargetWithPrefix(colName string, prefix []string) ast.Node { + fields := make([]ast.Node, 0, len(prefix)+1) + + // Add prefix parts (schema, table name) + for _, p := range prefix { + fields = append(fields, &ast.String{Str: p}) + } + + // Add column name + fields = append(fields, &ast.String{Str: colName}) + + return &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{Items: fields}, + }, + } +} diff --git a/internal/engine/postgresql/expander/expander_test.go b/internal/x/expander/expander_test.go similarity index 63% rename from internal/engine/postgresql/expander/expander_test.go rename to internal/x/expander/expander_test.go index 1d2024e6a8..82806bdcb6 100644 --- a/internal/engine/postgresql/expander/expander_test.go +++ b/internal/x/expander/expander_test.go @@ -6,6 +6,8 @@ import ( "testing" "github.com/jackc/pgx/v5/pgxpool" + + "github.com/sqlc-dev/sqlc/internal/engine/postgresql" ) func TestExpand(t *testing.T) { @@ -37,7 +39,11 @@ func TestExpand(t *testing.T) { } defer pool.Exec(ctx, "DROP TABLE IF EXISTS authors") - exp := New(pool) + // Create the parser which also implements format.Dialect + parser := postgresql.NewParser() + + // Create the expander + exp := New(pool, parser, parser) tests := []struct { name string @@ -47,62 +53,72 @@ func TestExpand(t *testing.T) { { name: "simple select star", query: "SELECT * FROM authors", - expected: "SELECT id, name, bio FROM authors", + expected: "SELECT id,name,bio FROM authors;", }, { name: "select with no star", query: "SELECT id, name FROM authors", - expected: "SELECT id, name FROM authors", + expected: "SELECT id, name FROM authors", // No change, returns original }, { name: "select star with where clause", query: "SELECT * FROM authors WHERE id = 1", - expected: "SELECT id, name, bio FROM authors WHERE id = 1", + expected: "SELECT id,name,bio FROM authors WHERE id = 1;", }, { name: "double star", query: "SELECT *, * FROM authors", - expected: "SELECT id, name, bio, id, name, bio FROM authors", + expected: "SELECT id,name,bio,id,name,bio FROM authors;", }, { name: "table qualified star", query: "SELECT authors.* FROM authors", - expected: "SELECT authors.id, authors.name, authors.bio FROM authors", + expected: "SELECT authors.id,authors.name,authors.bio FROM authors;", }, { name: "star in middle of columns", query: "SELECT id, *, name FROM authors", - expected: "SELECT id, id, name, bio, name FROM authors", + expected: "SELECT id,id,name,bio,name FROM authors;", }, { name: "insert returning star", query: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING *", - expected: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, name, bio", + expected: "INSERT INTO authors (name,bio) VALUES ('John','A writer') RETURNING id,name,bio;", }, { name: "insert returning mixed", query: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, *", - expected: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, id, name, bio", + expected: "INSERT INTO authors (name,bio) VALUES ('John','A writer') RETURNING id,id,name,bio;", }, { name: "update returning star", query: "UPDATE authors SET name = 'Jane' WHERE id = 1 RETURNING *", - expected: "UPDATE authors SET name = 'Jane' WHERE id = 1 RETURNING id, name, bio", + expected: "UPDATE authors SET name = 'Jane' WHERE id = 1 RETURNING id,name,bio;", }, { name: "delete returning star", query: "DELETE FROM authors WHERE id = 1 RETURNING *", - expected: "DELETE FROM authors WHERE id = 1 RETURNING id, name, bio", + expected: "DELETE FROM authors WHERE id = 1 RETURNING id,name,bio;", }, { name: "cte with select star", query: "WITH a AS (SELECT * FROM authors) SELECT * FROM a", - expected: "WITH a AS (SELECT id, name, bio FROM authors) SELECT id, name, bio FROM a", + expected: "WITH a AS (SELECT id,name,bio FROM authors) SELECT id,name,bio FROM a;", }, { name: "multiple ctes with dependency", query: "WITH a AS (SELECT * FROM authors), b AS (SELECT * FROM a) SELECT * FROM b", - expected: "WITH a AS (SELECT id, name, bio FROM authors), b AS (SELECT id, name, bio FROM a) SELECT id, name, bio FROM b", + expected: "WITH a AS (SELECT id,name,bio FROM authors), b AS (SELECT id,name,bio FROM a) SELECT id,name,bio FROM b;", + }, + { + name: "count star not expanded", + query: "SELECT COUNT(*) FROM authors", + expected: "SELECT COUNT(*) FROM authors", // No change - COUNT(*) should not be expanded + }, + { + name: "count star with other columns", + query: "SELECT COUNT(*), name FROM authors GROUP BY name", + expected: "SELECT COUNT(*), name FROM authors GROUP BY name", // No change }, } From 3b8932cae76ddf4ef9720e823f438b11409a0eec Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 18:59:41 -0800 Subject: [PATCH 3/8] feat(expander): add MySQL support and use ColumnGetter interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename TestExpand to TestExpandPostgreSQL - Add TestExpandMySQL for MySQL database support - Replace pgxpool.Pool with ColumnGetter interface for database-agnostic column resolution - Add PostgreSQLColumnGetter and MySQLColumnGetter implementations - MySQL tests skip edge cases (double star, star in middle) due to intermediate query formatting issues 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/x/expander/expander.go | 44 +++---- internal/x/expander/expander_test.go | 167 ++++++++++++++++++++++++++- 2 files changed, 180 insertions(+), 31 deletions(-) diff --git a/internal/x/expander/expander.go b/internal/x/expander/expander.go index baa8e1fd25..af0cab26e8 100644 --- a/internal/x/expander/expander.go +++ b/internal/x/expander/expander.go @@ -6,8 +6,6 @@ import ( "io" "strings" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/astutils" "github.com/sqlc-dev/sqlc/internal/sql/format" @@ -18,20 +16,25 @@ type Parser interface { Parse(r io.Reader) ([]ast.Statement, error) } +// ColumnGetter retrieves column names for a query by preparing it against a database. +type ColumnGetter interface { + GetColumnNames(ctx context.Context, query string) ([]string, error) +} + // Expander expands SELECT * and RETURNING * queries by replacing * with explicit column names -// obtained from preparing the query against a PostgreSQL database. +// obtained from preparing the query against a database. type Expander struct { - pool *pgxpool.Pool - parser Parser - dialect format.Dialect + colGetter ColumnGetter + parser Parser + dialect format.Dialect } -// New creates a new Expander with the given connection pool, parser, and dialect. -func New(pool *pgxpool.Pool, parser Parser, dialect format.Dialect) *Expander { +// New creates a new Expander with the given column getter, parser, and dialect. +func New(colGetter ColumnGetter, parser Parser, dialect format.Dialect) *Expander { return &Expander{ - pool: pool, - parser: parser, - dialect: dialect, + colGetter: colGetter, + parser: parser, + dialect: dialect, } } @@ -333,24 +336,7 @@ func hasStarInList(targets *ast.List) bool { // getColumnNames prepares the query and returns the column names from the result func (e *Expander) getColumnNames(ctx context.Context, query string) ([]string, error) { - conn, err := e.pool.Acquire(ctx) - if err != nil { - return nil, err - } - defer conn.Release() - - // Prepare the statement to get column metadata - desc, err := conn.Conn().Prepare(ctx, "", query) - if err != nil { - return nil, err - } - - columns := make([]string, len(desc.Fields)) - for i, field := range desc.Fields { - columns[i] = field.Name - } - - return columns, nil + return e.colGetter.GetColumnNames(ctx, query) } // countStarsInList counts the number of * expressions in a target list diff --git a/internal/x/expander/expander_test.go b/internal/x/expander/expander_test.go index 82806bdcb6..38b43fad8a 100644 --- a/internal/x/expander/expander_test.go +++ b/internal/x/expander/expander_test.go @@ -2,15 +2,62 @@ package expander import ( "context" + "database/sql" + "fmt" "os" "testing" + _ "github.com/go-sql-driver/mysql" "github.com/jackc/pgx/v5/pgxpool" + "github.com/sqlc-dev/sqlc/internal/engine/dolphin" "github.com/sqlc-dev/sqlc/internal/engine/postgresql" ) -func TestExpand(t *testing.T) { +// PostgreSQLColumnGetter implements ColumnGetter for PostgreSQL using pgxpool. +type PostgreSQLColumnGetter struct { + pool *pgxpool.Pool +} + +func (g *PostgreSQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { + conn, err := g.pool.Acquire(ctx) + if err != nil { + return nil, err + } + defer conn.Release() + + desc, err := conn.Conn().Prepare(ctx, "", query) + if err != nil { + return nil, err + } + + columns := make([]string, len(desc.Fields)) + for i, field := range desc.Fields { + columns[i] = field.Name + } + + return columns, nil +} + +// MySQLColumnGetter implements ColumnGetter for MySQL using database/sql. +type MySQLColumnGetter struct { + db *sql.DB +} + +func (g *MySQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { + // Use LIMIT 0 to get column metadata without fetching rows + limitedQuery := query + // For SELECT queries, add LIMIT 0 if not already present + rows, err := g.db.QueryContext(ctx, limitedQuery) + if err != nil { + return nil, err + } + defer rows.Close() + + return rows.Columns() +} + +func TestExpandPostgreSQL(t *testing.T) { // Skip if no database connection available uri := os.Getenv("POSTGRESQL_SERVER_URI") if uri == "" { @@ -43,7 +90,8 @@ func TestExpand(t *testing.T) { parser := postgresql.NewParser() // Create the expander - exp := New(pool, parser, parser) + colGetter := &PostgreSQLColumnGetter{pool: pool} + exp := New(colGetter, parser, parser) tests := []struct { name string @@ -134,3 +182,118 @@ func TestExpand(t *testing.T) { }) } } + +func TestExpandMySQL(t *testing.T) { + // Get MySQL connection parameters + user := os.Getenv("MYSQL_USER") + if user == "" { + user = "root" + } + pass := os.Getenv("MYSQL_ROOT_PASSWORD") + if pass == "" { + pass = "mysecretpassword" + } + host := os.Getenv("MYSQL_HOST") + if host == "" { + host = "127.0.0.1" + } + port := os.Getenv("MYSQL_PORT") + if port == "" { + port = "3306" + } + dbname := os.Getenv("MYSQL_DATABASE") + if dbname == "" { + dbname = "dinotest" + } + + source := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true&parseTime=true", user, pass, host, port, dbname) + + ctx := context.Background() + + db, err := sql.Open("mysql", source) + if err != nil { + t.Skipf("could not connect to MySQL: %v", err) + } + defer db.Close() + + // Verify connection + if err := db.Ping(); err != nil { + t.Skipf("could not ping MySQL: %v", err) + } + + // Create a test table + _, err = db.ExecContext(ctx, `DROP TABLE IF EXISTS authors`) + if err != nil { + t.Fatalf("failed to drop test table: %v", err) + } + _, err = db.ExecContext(ctx, ` + CREATE TABLE authors ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) NOT NULL, + bio TEXT + ) + `) + if err != nil { + t.Fatalf("failed to create test table: %v", err) + } + defer db.ExecContext(ctx, "DROP TABLE IF EXISTS authors") + + // Create the parser which also implements format.Dialect + parser := dolphin.NewParser() + + // Create the expander + colGetter := &MySQLColumnGetter{db: db} + exp := New(colGetter, parser, parser) + + tests := []struct { + name string + query string + expected string + }{ + { + name: "simple select star", + query: "SELECT * FROM authors", + expected: "SELECT id,name,bio FROM authors;", + }, + { + name: "select with no star", + query: "SELECT id, name FROM authors", + expected: "SELECT id, name FROM authors", // No change, returns original + }, + { + name: "select star with where clause", + query: "SELECT * FROM authors WHERE id = 1", + expected: "SELECT id,name,bio FROM authors WHERE id = 1;", + }, + { + name: "table qualified star", + query: "SELECT authors.* FROM authors", + expected: "SELECT authors.id,authors.name,authors.bio FROM authors;", + }, + { + name: "count star not expanded", + query: "SELECT COUNT(*) FROM authors", + expected: "SELECT COUNT(*) FROM authors", // No change - COUNT(*) should not be expanded + }, + { + name: "count star with other columns", + query: "SELECT COUNT(*), name FROM authors GROUP BY name", + expected: "SELECT COUNT(*), name FROM authors GROUP BY name", // No change + }, + // Note: "double star" and "star in middle of columns" tests are skipped for MySQL + // because the intermediate query formatting produces invalid MySQL syntax. + // These are edge cases that rarely occur in real-world usage. + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := exp.Expand(ctx, tc.query) + if err != nil { + t.Fatalf("Expand failed: %v", err) + } + if result != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, result) + } + }) + } +} From dd128fe5f11d3fd4444d05cc511c75c7d5fc6559 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 19:02:24 -0800 Subject: [PATCH 4/8] fix(expander): use valid MySQL syntax for edge case tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MySQL doesn't support unqualified `*` mixed with other columns (e.g., `SELECT *, *` or `SELECT id, *, name`). These are valid PostgreSQL but invalid MySQL syntax. Update MySQL tests to use table-qualified stars which are valid: - `SELECT authors.*, authors.*` instead of `SELECT *, *` - `SELECT id, authors.*, name` instead of `SELECT id, *, name` 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/x/expander/expander_test.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/internal/x/expander/expander_test.go b/internal/x/expander/expander_test.go index 38b43fad8a..18b0c966fd 100644 --- a/internal/x/expander/expander_test.go +++ b/internal/x/expander/expander_test.go @@ -270,6 +270,16 @@ func TestExpandMySQL(t *testing.T) { query: "SELECT authors.* FROM authors", expected: "SELECT authors.id,authors.name,authors.bio FROM authors;", }, + { + name: "double table qualified star", + query: "SELECT authors.*, authors.* FROM authors", + expected: "SELECT authors.id,authors.name,authors.bio,authors.id,authors.name,authors.bio FROM authors;", + }, + { + name: "star in middle of columns table qualified", + query: "SELECT id, authors.*, name FROM authors", + expected: "SELECT id,authors.id,authors.name,authors.bio,name FROM authors;", + }, { name: "count star not expanded", query: "SELECT COUNT(*) FROM authors", @@ -280,9 +290,6 @@ func TestExpandMySQL(t *testing.T) { query: "SELECT COUNT(*), name FROM authors GROUP BY name", expected: "SELECT COUNT(*), name FROM authors GROUP BY name", // No change }, - // Note: "double star" and "star in middle of columns" tests are skipped for MySQL - // because the intermediate query formatting produces invalid MySQL syntax. - // These are edge cases that rarely occur in real-world usage. } for _, tc := range tests { From e7836b8de9e54bcb9f24c54821e442c037aa1a99 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 19:06:36 -0800 Subject: [PATCH 5/8] feat(expander): add SQLite support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add TestExpandSQLite with 8 test cases using in-memory SQLite database - Rename MySQLColumnGetter to SQLColumnGetter since both MySQL and SQLite use the same database/sql-based implementation - SQLite supports the same star syntax as PostgreSQL (including `SELECT *, *` and `SELECT id, *, name`) Test results: - PostgreSQL: 14 tests - MySQL: 8 tests - SQLite: 8 tests 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/x/expander/expander_test.go | 105 +++++++++++++++++++++++++-- 1 file changed, 97 insertions(+), 8 deletions(-) diff --git a/internal/x/expander/expander_test.go b/internal/x/expander/expander_test.go index 18b0c966fd..3f8320ea89 100644 --- a/internal/x/expander/expander_test.go +++ b/internal/x/expander/expander_test.go @@ -9,9 +9,12 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/jackc/pgx/v5/pgxpool" + _ "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" "github.com/sqlc-dev/sqlc/internal/engine/dolphin" "github.com/sqlc-dev/sqlc/internal/engine/postgresql" + "github.com/sqlc-dev/sqlc/internal/engine/sqlite" ) // PostgreSQLColumnGetter implements ColumnGetter for PostgreSQL using pgxpool. @@ -39,16 +42,13 @@ func (g *PostgreSQLColumnGetter) GetColumnNames(ctx context.Context, query strin return columns, nil } -// MySQLColumnGetter implements ColumnGetter for MySQL using database/sql. -type MySQLColumnGetter struct { +// SQLColumnGetter implements ColumnGetter for MySQL and SQLite using database/sql. +type SQLColumnGetter struct { db *sql.DB } -func (g *MySQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { - // Use LIMIT 0 to get column metadata without fetching rows - limitedQuery := query - // For SELECT queries, add LIMIT 0 if not already present - rows, err := g.db.QueryContext(ctx, limitedQuery) +func (g *SQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { + rows, err := g.db.QueryContext(ctx, query) if err != nil { return nil, err } @@ -242,7 +242,7 @@ func TestExpandMySQL(t *testing.T) { parser := dolphin.NewParser() // Create the expander - colGetter := &MySQLColumnGetter{db: db} + colGetter := &SQLColumnGetter{db: db} exp := New(colGetter, parser, parser) tests := []struct { @@ -304,3 +304,92 @@ func TestExpandMySQL(t *testing.T) { }) } } + +func TestExpandSQLite(t *testing.T) { + ctx := context.Background() + + // Create an in-memory SQLite database + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("could not open SQLite: %v", err) + } + defer db.Close() + + // Create a test table + _, err = db.ExecContext(ctx, ` + CREATE TABLE authors ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + bio TEXT + ) + `) + if err != nil { + t.Fatalf("failed to create test table: %v", err) + } + + // Create the parser which also implements format.Dialect + parser := sqlite.NewParser() + + // Create the expander + colGetter := &SQLColumnGetter{db: db} + exp := New(colGetter, parser, parser) + + tests := []struct { + name string + query string + expected string + }{ + { + name: "simple select star", + query: "SELECT * FROM authors", + expected: "SELECT id,name,bio FROM authors;", + }, + { + name: "select with no star", + query: "SELECT id, name FROM authors", + expected: "SELECT id, name FROM authors", // No change, returns original + }, + { + name: "select star with where clause", + query: "SELECT * FROM authors WHERE id = 1", + expected: "SELECT id,name,bio FROM authors WHERE id = 1;", + }, + { + name: "double star", + query: "SELECT *, * FROM authors", + expected: "SELECT id,name,bio,id,name,bio FROM authors;", + }, + { + name: "table qualified star", + query: "SELECT authors.* FROM authors", + expected: "SELECT authors.id,authors.name,authors.bio FROM authors;", + }, + { + name: "star in middle of columns", + query: "SELECT id, *, name FROM authors", + expected: "SELECT id,id,name,bio,name FROM authors;", + }, + { + name: "count star not expanded", + query: "SELECT COUNT(*) FROM authors", + expected: "SELECT COUNT(*) FROM authors", // No change - COUNT(*) should not be expanded + }, + { + name: "count star with other columns", + query: "SELECT COUNT(*), name FROM authors GROUP BY name", + expected: "SELECT COUNT(*), name FROM authors GROUP BY name", // No change + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := exp.Expand(ctx, tc.query) + if err != nil { + t.Fatalf("Expand failed: %v", err) + } + if result != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, result) + } + }) + } +} From 3c31da6bc45feebb18f8341f662648cba9593a51 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 19:09:33 -0800 Subject: [PATCH 6/8] refactor(expander): use PrepareContext in SQLColumnGetter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use PrepareContext to validate the query before executing it to get column metadata. While database/sql doesn't expose column names from prepared statements directly (unlike pgx), this at least validates the SQL syntax before execution. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/x/expander/expander_test.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/internal/x/expander/expander_test.go b/internal/x/expander/expander_test.go index 3f8320ea89..b1a5c20ef5 100644 --- a/internal/x/expander/expander_test.go +++ b/internal/x/expander/expander_test.go @@ -48,7 +48,16 @@ type SQLColumnGetter struct { } func (g *SQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { - rows, err := g.db.QueryContext(ctx, query) + // Prepare the statement to validate the query and get column metadata + stmt, err := g.db.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + defer stmt.Close() + + // Execute with LIMIT 0 workaround by wrapping in a subquery to get column names + // without fetching actual data. We need to execute to get column metadata from database/sql. + rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err } From ffc490ba8b3ac67edf85626c6fc02140f9404764 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 19:18:28 -0800 Subject: [PATCH 7/8] refactor(expander): use native ncruces/go-sqlite3 API for column names MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use the native sqlite3.Conn.Prepare and stmt.ColumnName/ColumnCount APIs to get column names without executing the query. This is more efficient and consistent with how PostgreSQL handles it. Changes: - Add SQLiteColumnGetter using native sqlite3.Conn - Rename SQLColumnGetter to MySQLColumnGetter (MySQL still needs to execute) - SQLite test now uses sqlite3.Open instead of sql.Open 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/x/expander/expander_test.go | 49 ++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/internal/x/expander/expander_test.go b/internal/x/expander/expander_test.go index b1a5c20ef5..3297173fea 100644 --- a/internal/x/expander/expander_test.go +++ b/internal/x/expander/expander_test.go @@ -9,7 +9,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/jackc/pgx/v5/pgxpool" - _ "github.com/ncruces/go-sqlite3/driver" + "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/embed" "github.com/sqlc-dev/sqlc/internal/engine/dolphin" @@ -42,12 +42,12 @@ func (g *PostgreSQLColumnGetter) GetColumnNames(ctx context.Context, query strin return columns, nil } -// SQLColumnGetter implements ColumnGetter for MySQL and SQLite using database/sql. -type SQLColumnGetter struct { +// MySQLColumnGetter implements ColumnGetter for MySQL using database/sql. +type MySQLColumnGetter struct { db *sql.DB } -func (g *SQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { +func (g *MySQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { // Prepare the statement to validate the query and get column metadata stmt, err := g.db.PrepareContext(ctx, query) if err != nil { @@ -55,8 +55,8 @@ func (g *SQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]s } defer stmt.Close() - // Execute with LIMIT 0 workaround by wrapping in a subquery to get column names - // without fetching actual data. We need to execute to get column metadata from database/sql. + // Execute to get column metadata from database/sql. + // database/sql doesn't expose column names from prepared statements directly. rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err @@ -66,6 +66,29 @@ func (g *SQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]s return rows.Columns() } +// SQLiteColumnGetter implements ColumnGetter for SQLite using the native ncruces/go-sqlite3 API. +type SQLiteColumnGetter struct { + conn *sqlite3.Conn +} + +func (g *SQLiteColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { + // Prepare the statement - this gives us column metadata without executing + stmt, _, err := g.conn.Prepare(query) + if err != nil { + return nil, err + } + defer stmt.Close() + + // Get column names from the prepared statement + count := stmt.ColumnCount() + columns := make([]string, count) + for i := 0; i < count; i++ { + columns[i] = stmt.ColumnName(i) + } + + return columns, nil +} + func TestExpandPostgreSQL(t *testing.T) { // Skip if no database connection available uri := os.Getenv("POSTGRESQL_SERVER_URI") @@ -251,7 +274,7 @@ func TestExpandMySQL(t *testing.T) { parser := dolphin.NewParser() // Create the expander - colGetter := &SQLColumnGetter{db: db} + colGetter := &MySQLColumnGetter{db: db} exp := New(colGetter, parser, parser) tests := []struct { @@ -317,15 +340,15 @@ func TestExpandMySQL(t *testing.T) { func TestExpandSQLite(t *testing.T) { ctx := context.Background() - // Create an in-memory SQLite database - db, err := sql.Open("sqlite3", ":memory:") + // Create an in-memory SQLite database using native API + conn, err := sqlite3.Open(":memory:") if err != nil { t.Fatalf("could not open SQLite: %v", err) } - defer db.Close() + defer conn.Close() // Create a test table - _, err = db.ExecContext(ctx, ` + err = conn.Exec(` CREATE TABLE authors ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, @@ -339,8 +362,8 @@ func TestExpandSQLite(t *testing.T) { // Create the parser which also implements format.Dialect parser := sqlite.NewParser() - // Create the expander - colGetter := &SQLColumnGetter{db: db} + // Create the expander using native SQLite column getter + colGetter := &SQLiteColumnGetter{conn: conn} exp := New(colGetter, parser, parser) tests := []struct { From f1bbc6ed028766593377724558e626d1b9c04899 Mon Sep 17 00:00:00 2001 From: Kyle Conroy Date: Sun, 30 Nov 2025 19:30:53 -0800 Subject: [PATCH 8/8] refactor(expander): use forked MySQL driver and fix list formatting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Update MySQLColumnGetter to use github.com/sqlc-dev/mysql fork with StmtMetadata interface for getting column names via prepare - Add replace directive to go.mod for the forked MySQL driver - Fix list formatting to use ", " separator instead of "," for proper SQL spacing (e.g., "SELECT id, name, bio" instead of "SELECT id,name,bio") - Update test expectations to reflect proper spacing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- go.mod | 2 + go.sum | 4 +- internal/sql/ast/list.go | 2 +- internal/x/expander/expander_test.go | 81 +++++++++++++++++----------- 4 files changed, 55 insertions(+), 34 deletions(-) diff --git a/go.mod b/go.mod index 450573ddab..630795248e 100644 --- a/go.mod +++ b/go.mod @@ -64,3 +64,5 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) + +replace github.com/go-sql-driver/mysql => github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2 diff --git a/go.sum b/go.sum index 3178cae5c1..002020f15c 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,6 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= -github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= @@ -159,6 +157,8 @@ github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4 github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2 h1:kmCAKKtOgK6EXXQX9oPdEASIhgor7TCpWxD8NtcqVcU= +github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2/go.mod h1:TrDMWzjNTKvJeK2GC8uspG+PWyPLiY9QKvwdWpAdlZE= github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/internal/sql/ast/list.go b/internal/sql/ast/list.go index 38be310e3c..3bb9d90dcd 100644 --- a/internal/sql/ast/list.go +++ b/internal/sql/ast/list.go @@ -14,5 +14,5 @@ func (n *List) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.join(n, d, ",") + buf.join(n, d, ", ") } diff --git a/internal/x/expander/expander_test.go b/internal/x/expander/expander_test.go index 3297173fea..84de74cdf3 100644 --- a/internal/x/expander/expander_test.go +++ b/internal/x/expander/expander_test.go @@ -3,11 +3,12 @@ package expander import ( "context" "database/sql" + "database/sql/driver" "fmt" "os" "testing" - _ "github.com/go-sql-driver/mysql" + "github.com/go-sql-driver/mysql" "github.com/jackc/pgx/v5/pgxpool" "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/embed" @@ -42,28 +43,46 @@ func (g *PostgreSQLColumnGetter) GetColumnNames(ctx context.Context, query strin return columns, nil } -// MySQLColumnGetter implements ColumnGetter for MySQL using database/sql. +// MySQLColumnGetter implements ColumnGetter for MySQL using the forked driver's StmtMetadata. type MySQLColumnGetter struct { db *sql.DB } func (g *MySQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { - // Prepare the statement to validate the query and get column metadata - stmt, err := g.db.PrepareContext(ctx, query) + conn, err := g.db.Conn(ctx) if err != nil { return nil, err } - defer stmt.Close() + defer conn.Close() - // Execute to get column metadata from database/sql. - // database/sql doesn't expose column names from prepared statements directly. - rows, err := stmt.QueryContext(ctx) + var columns []string + err = conn.Raw(func(driverConn any) error { + preparer, ok := driverConn.(driver.ConnPrepareContext) + if !ok { + return fmt.Errorf("driver connection does not support PrepareContext") + } + + stmt, err := preparer.PrepareContext(ctx, query) + if err != nil { + return err + } + defer stmt.Close() + + meta, ok := stmt.(mysql.StmtMetadata) + if !ok { + return fmt.Errorf("prepared statement does not implement StmtMetadata") + } + + for _, col := range meta.ColumnMetadata() { + columns = append(columns, col.Name) + } + return nil + }) if err != nil { return nil, err } - defer rows.Close() - return rows.Columns() + return columns, nil } // SQLiteColumnGetter implements ColumnGetter for SQLite using the native ncruces/go-sqlite3 API. @@ -133,7 +152,7 @@ func TestExpandPostgreSQL(t *testing.T) { { name: "simple select star", query: "SELECT * FROM authors", - expected: "SELECT id,name,bio FROM authors;", + expected: "SELECT id, name, bio FROM authors;", }, { name: "select with no star", @@ -143,52 +162,52 @@ func TestExpandPostgreSQL(t *testing.T) { { name: "select star with where clause", query: "SELECT * FROM authors WHERE id = 1", - expected: "SELECT id,name,bio FROM authors WHERE id = 1;", + expected: "SELECT id, name, bio FROM authors WHERE id = 1;", }, { name: "double star", query: "SELECT *, * FROM authors", - expected: "SELECT id,name,bio,id,name,bio FROM authors;", + expected: "SELECT id, name, bio, id, name, bio FROM authors;", }, { name: "table qualified star", query: "SELECT authors.* FROM authors", - expected: "SELECT authors.id,authors.name,authors.bio FROM authors;", + expected: "SELECT authors.id, authors.name, authors.bio FROM authors;", }, { name: "star in middle of columns", query: "SELECT id, *, name FROM authors", - expected: "SELECT id,id,name,bio,name FROM authors;", + expected: "SELECT id, id, name, bio, name FROM authors;", }, { name: "insert returning star", query: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING *", - expected: "INSERT INTO authors (name,bio) VALUES ('John','A writer') RETURNING id,name,bio;", + expected: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, name, bio;", }, { name: "insert returning mixed", query: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, *", - expected: "INSERT INTO authors (name,bio) VALUES ('John','A writer') RETURNING id,id,name,bio;", + expected: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, id, name, bio;", }, { name: "update returning star", query: "UPDATE authors SET name = 'Jane' WHERE id = 1 RETURNING *", - expected: "UPDATE authors SET name = 'Jane' WHERE id = 1 RETURNING id,name,bio;", + expected: "UPDATE authors SET name = 'Jane' WHERE id = 1 RETURNING id, name, bio;", }, { name: "delete returning star", query: "DELETE FROM authors WHERE id = 1 RETURNING *", - expected: "DELETE FROM authors WHERE id = 1 RETURNING id,name,bio;", + expected: "DELETE FROM authors WHERE id = 1 RETURNING id, name, bio;", }, { name: "cte with select star", query: "WITH a AS (SELECT * FROM authors) SELECT * FROM a", - expected: "WITH a AS (SELECT id,name,bio FROM authors) SELECT id,name,bio FROM a;", + expected: "WITH a AS (SELECT id, name, bio FROM authors) SELECT id, name, bio FROM a;", }, { name: "multiple ctes with dependency", query: "WITH a AS (SELECT * FROM authors), b AS (SELECT * FROM a) SELECT * FROM b", - expected: "WITH a AS (SELECT id,name,bio FROM authors), b AS (SELECT id,name,bio FROM a) SELECT id,name,bio FROM b;", + expected: "WITH a AS (SELECT id, name, bio FROM authors), b AS (SELECT id, name, bio FROM a) SELECT id, name, bio FROM b;", }, { name: "count star not expanded", @@ -285,7 +304,7 @@ func TestExpandMySQL(t *testing.T) { { name: "simple select star", query: "SELECT * FROM authors", - expected: "SELECT id,name,bio FROM authors;", + expected: "SELECT id, name, bio FROM authors;", }, { name: "select with no star", @@ -295,22 +314,22 @@ func TestExpandMySQL(t *testing.T) { { name: "select star with where clause", query: "SELECT * FROM authors WHERE id = 1", - expected: "SELECT id,name,bio FROM authors WHERE id = 1;", + expected: "SELECT id, name, bio FROM authors WHERE id = 1;", }, { name: "table qualified star", query: "SELECT authors.* FROM authors", - expected: "SELECT authors.id,authors.name,authors.bio FROM authors;", + expected: "SELECT authors.id, authors.name, authors.bio FROM authors;", }, { name: "double table qualified star", query: "SELECT authors.*, authors.* FROM authors", - expected: "SELECT authors.id,authors.name,authors.bio,authors.id,authors.name,authors.bio FROM authors;", + expected: "SELECT authors.id, authors.name, authors.bio, authors.id, authors.name, authors.bio FROM authors;", }, { name: "star in middle of columns table qualified", query: "SELECT id, authors.*, name FROM authors", - expected: "SELECT id,authors.id,authors.name,authors.bio,name FROM authors;", + expected: "SELECT id, authors.id, authors.name, authors.bio, name FROM authors;", }, { name: "count star not expanded", @@ -374,7 +393,7 @@ func TestExpandSQLite(t *testing.T) { { name: "simple select star", query: "SELECT * FROM authors", - expected: "SELECT id,name,bio FROM authors;", + expected: "SELECT id, name, bio FROM authors;", }, { name: "select with no star", @@ -384,22 +403,22 @@ func TestExpandSQLite(t *testing.T) { { name: "select star with where clause", query: "SELECT * FROM authors WHERE id = 1", - expected: "SELECT id,name,bio FROM authors WHERE id = 1;", + expected: "SELECT id, name, bio FROM authors WHERE id = 1;", }, { name: "double star", query: "SELECT *, * FROM authors", - expected: "SELECT id,name,bio,id,name,bio FROM authors;", + expected: "SELECT id, name, bio, id, name, bio FROM authors;", }, { name: "table qualified star", query: "SELECT authors.* FROM authors", - expected: "SELECT authors.id,authors.name,authors.bio FROM authors;", + expected: "SELECT authors.id, authors.name, authors.bio FROM authors;", }, { name: "star in middle of columns", query: "SELECT id, *, name FROM authors", - expected: "SELECT id,id,name,bio,name FROM authors;", + expected: "SELECT id, id, name, bio, name FROM authors;", }, { name: "count star not expanded",