diff --git a/go.mod b/go.mod index 450573ddab..630795248e 100644 --- a/go.mod +++ b/go.mod @@ -64,3 +64,5 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) + +replace github.com/go-sql-driver/mysql => github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2 diff --git a/go.sum b/go.sum index 3178cae5c1..002020f15c 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,6 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= -github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= @@ -159,6 +157,8 @@ github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4 github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2 h1:kmCAKKtOgK6EXXQX9oPdEASIhgor7TCpWxD8NtcqVcU= +github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2/go.mod h1:TrDMWzjNTKvJeK2GC8uspG+PWyPLiY9QKvwdWpAdlZE= github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/internal/sql/ast/list.go b/internal/sql/ast/list.go index 38be310e3c..3bb9d90dcd 100644 --- a/internal/sql/ast/list.go +++ b/internal/sql/ast/list.go @@ -14,5 +14,5 @@ func (n *List) Format(buf *TrackedBuffer, d format.Dialect) { if n == nil { return } - buf.join(n, d, ",") + buf.join(n, d, ", ") } diff --git a/internal/x/expander/expander.go b/internal/x/expander/expander.go new file mode 100644 index 0000000000..af0cab26e8 --- /dev/null +++ b/internal/x/expander/expander.go @@ -0,0 +1,507 @@ +package expander + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/astutils" + "github.com/sqlc-dev/sqlc/internal/sql/format" +) + +// Parser is an interface for SQL parsers that can parse SQL into AST statements. +type Parser interface { + Parse(r io.Reader) ([]ast.Statement, error) +} + +// ColumnGetter retrieves column names for a query by preparing it against a database. +type ColumnGetter interface { + GetColumnNames(ctx context.Context, query string) ([]string, error) +} + +// Expander expands SELECT * and RETURNING * queries by replacing * with explicit column names +// obtained from preparing the query against a database. +type Expander struct { + colGetter ColumnGetter + parser Parser + dialect format.Dialect +} + +// New creates a new Expander with the given column getter, parser, and dialect. +func New(colGetter ColumnGetter, parser Parser, dialect format.Dialect) *Expander { + return &Expander{ + colGetter: colGetter, + parser: parser, + dialect: dialect, + } +} + +// Expand takes a SQL query, and if it contains * in SELECT or RETURNING clause, +// expands it to use explicit column names. Returns the expanded query string. +func (e *Expander) Expand(ctx context.Context, query string) (string, error) { + // Parse the query + stmts, err := e.parser.Parse(strings.NewReader(query)) + if err != nil { + return "", fmt.Errorf("failed to parse query: %w", err) + } + + if len(stmts) == 0 { + return query, nil + } + + stmt := stmts[0].Raw.Stmt + + // Check if there's any star in the statement (including CTEs, subqueries, etc.) + if !hasStarAnywhere(stmt) { + return query, nil + } + + // Expand all stars in the statement recursively + if err := e.expandNode(ctx, stmt); err != nil { + return "", err + } + + // Format the modified AST back to SQL + expanded := ast.Format(stmts[0].Raw, e.dialect) + + return expanded, nil +} + +// expandNode recursively expands * in all parts of the statement +func (e *Expander) expandNode(ctx context.Context, node ast.Node) error { + if node == nil { + return nil + } + + switch n := node.(type) { + case *ast.SelectStmt: + return e.expandSelectStmt(ctx, n) + case *ast.InsertStmt: + return e.expandInsertStmt(ctx, n) + case *ast.UpdateStmt: + return e.expandUpdateStmt(ctx, n) + case *ast.DeleteStmt: + return e.expandDeleteStmt(ctx, n) + case *ast.CommonTableExpr: + return e.expandNode(ctx, n.Ctequery) + } + return nil +} + +// expandSelectStmt expands * in a SELECT statement including CTEs and subqueries +func (e *Expander) expandSelectStmt(ctx context.Context, stmt *ast.SelectStmt) error { + // First expand any CTEs - must be done in order since later CTEs may depend on earlier ones + if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { + for _, cteNode := range stmt.WithClause.Ctes.Items { + cte, ok := cteNode.(*ast.CommonTableExpr) + if !ok { + continue + } + cteSelect, ok := cte.Ctequery.(*ast.SelectStmt) + if !ok { + continue + } + if hasStarInList(cteSelect.TargetList) { + // Get column names for this CTE + columns, err := e.getCTEColumnNames(ctx, stmt, cte) + if err != nil { + return err + } + cteSelect.TargetList = rewriteTargetList(cteSelect.TargetList, columns) + } + // Recursively handle nested CTEs/subqueries in this CTE + if err := e.expandSelectStmtInner(ctx, cteSelect); err != nil { + return err + } + } + } + + // Expand subqueries in FROM clause + if stmt.FromClause != nil { + for _, fromItem := range stmt.FromClause.Items { + if err := e.expandFromClause(ctx, fromItem); err != nil { + return err + } + } + } + + // Expand the target list if it has stars + if hasStarInList(stmt.TargetList) { + // Format the current state to get columns + tempRaw := &ast.RawStmt{Stmt: stmt} + tempQuery := ast.Format(tempRaw, e.dialect) + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.TargetList = rewriteTargetList(stmt.TargetList, columns) + } + + return nil +} + +// expandSelectStmtInner expands nested structures without re-processing the target list +func (e *Expander) expandSelectStmtInner(ctx context.Context, stmt *ast.SelectStmt) error { + // Expand subqueries in FROM clause + if stmt.FromClause != nil { + for _, fromItem := range stmt.FromClause.Items { + if err := e.expandFromClause(ctx, fromItem); err != nil { + return err + } + } + } + return nil +} + +// getCTEColumnNames gets the column names for a CTE by constructing a query with proper context +func (e *Expander) getCTEColumnNames(ctx context.Context, stmt *ast.SelectStmt, targetCTE *ast.CommonTableExpr) ([]string, error) { + // Build a temporary query: WITH SELECT * FROM + var ctesToInclude []ast.Node + for _, cteNode := range stmt.WithClause.Ctes.Items { + ctesToInclude = append(ctesToInclude, cteNode) + cte, ok := cteNode.(*ast.CommonTableExpr) + if ok && cte.Ctename != nil && targetCTE.Ctename != nil && *cte.Ctename == *targetCTE.Ctename { + break + } + } + + // Create a SELECT * FROM with the relevant CTEs + cteName := "" + if targetCTE.Ctename != nil { + cteName = *targetCTE.Ctename + } + + tempStmt := &ast.SelectStmt{ + WithClause: &ast.WithClause{ + Ctes: &ast.List{Items: ctesToInclude}, + Recursive: stmt.WithClause.Recursive, + }, + TargetList: &ast.List{ + Items: []ast.Node{ + &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{ + Items: []ast.Node{&ast.A_Star{}}, + }, + }, + }, + }, + }, + FromClause: &ast.List{ + Items: []ast.Node{ + &ast.RangeVar{ + Relname: &cteName, + }, + }, + }, + } + + tempRaw := &ast.RawStmt{Stmt: tempStmt} + tempQuery := ast.Format(tempRaw, e.dialect) + + return e.getColumnNames(ctx, tempQuery) +} + +// expandInsertStmt expands * in an INSERT statement's RETURNING clause +func (e *Expander) expandInsertStmt(ctx context.Context, stmt *ast.InsertStmt) error { + // Expand CTEs first + if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { + for _, cte := range stmt.WithClause.Ctes.Items { + if err := e.expandNode(ctx, cte); err != nil { + return err + } + } + } + + // Expand the SELECT part if present + if stmt.SelectStmt != nil { + if err := e.expandNode(ctx, stmt.SelectStmt); err != nil { + return err + } + } + + // Expand RETURNING clause + if hasStarInList(stmt.ReturningList) { + tempRaw := &ast.RawStmt{Stmt: stmt} + tempQuery := ast.Format(tempRaw, e.dialect) + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) + } + + return nil +} + +// expandUpdateStmt expands * in an UPDATE statement's RETURNING clause +func (e *Expander) expandUpdateStmt(ctx context.Context, stmt *ast.UpdateStmt) error { + // Expand CTEs first + if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { + for _, cte := range stmt.WithClause.Ctes.Items { + if err := e.expandNode(ctx, cte); err != nil { + return err + } + } + } + + // Expand RETURNING clause + if hasStarInList(stmt.ReturningList) { + tempRaw := &ast.RawStmt{Stmt: stmt} + tempQuery := ast.Format(tempRaw, e.dialect) + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) + } + + return nil +} + +// expandDeleteStmt expands * in a DELETE statement's RETURNING clause +func (e *Expander) expandDeleteStmt(ctx context.Context, stmt *ast.DeleteStmt) error { + // Expand CTEs first + if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { + for _, cte := range stmt.WithClause.Ctes.Items { + if err := e.expandNode(ctx, cte); err != nil { + return err + } + } + } + + // Expand RETURNING clause + if hasStarInList(stmt.ReturningList) { + tempRaw := &ast.RawStmt{Stmt: stmt} + tempQuery := ast.Format(tempRaw, e.dialect) + columns, err := e.getColumnNames(ctx, tempQuery) + if err != nil { + return fmt.Errorf("failed to get column names: %w", err) + } + stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) + } + + return nil +} + +// expandFromClause expands * in subqueries within FROM clause +func (e *Expander) expandFromClause(ctx context.Context, node ast.Node) error { + if node == nil { + return nil + } + + switch n := node.(type) { + case *ast.RangeSubselect: + if n.Subquery != nil { + return e.expandNode(ctx, n.Subquery) + } + case *ast.JoinExpr: + if err := e.expandFromClause(ctx, n.Larg); err != nil { + return err + } + if err := e.expandFromClause(ctx, n.Rarg); err != nil { + return err + } + } + return nil +} + +// hasStarAnywhere checks if there's a * anywhere in the statement using astutils.Search +func hasStarAnywhere(node ast.Node) bool { + if node == nil { + return false + } + // Use astutils.Search to find any A_Star node in the AST + stars := astutils.Search(node, func(n ast.Node) bool { + _, ok := n.(*ast.A_Star) + return ok + }) + return len(stars.Items) > 0 +} + +// hasStarInList checks if a target list contains a * expression using astutils.Search +func hasStarInList(targets *ast.List) bool { + if targets == nil { + return false + } + // Use astutils.Search to find any A_Star node in the target list + stars := astutils.Search(targets, func(n ast.Node) bool { + _, ok := n.(*ast.A_Star) + return ok + }) + return len(stars.Items) > 0 +} + +// getColumnNames prepares the query and returns the column names from the result +func (e *Expander) getColumnNames(ctx context.Context, query string) ([]string, error) { + return e.colGetter.GetColumnNames(ctx, query) +} + +// countStarsInList counts the number of * expressions in a target list +func countStarsInList(targets *ast.List) int { + if targets == nil { + return 0 + } + count := 0 + for _, target := range targets.Items { + resTarget, ok := target.(*ast.ResTarget) + if !ok { + continue + } + if resTarget.Val == nil { + continue + } + colRef, ok := resTarget.Val.(*ast.ColumnRef) + if !ok { + continue + } + if colRef.Fields == nil { + continue + } + for _, field := range colRef.Fields.Items { + if _, ok := field.(*ast.A_Star); ok { + count++ + break + } + } + } + return count +} + +// countNonStarsInList counts the number of non-* expressions in a target list +func countNonStarsInList(targets *ast.List) int { + if targets == nil { + return 0 + } + count := 0 + for _, target := range targets.Items { + resTarget, ok := target.(*ast.ResTarget) + if !ok { + count++ + continue + } + if resTarget.Val == nil { + count++ + continue + } + colRef, ok := resTarget.Val.(*ast.ColumnRef) + if !ok { + count++ + continue + } + if colRef.Fields == nil { + count++ + continue + } + isStar := false + for _, field := range colRef.Fields.Items { + if _, ok := field.(*ast.A_Star); ok { + isStar = true + break + } + } + if !isStar { + count++ + } + } + return count +} + +// rewriteTargetList replaces * in a target list with explicit column references +func rewriteTargetList(targets *ast.List, columns []string) *ast.List { + if targets == nil { + return nil + } + + starCount := countStarsInList(targets) + nonStarCount := countNonStarsInList(targets) + + // Calculate how many columns each * expands to + // Total columns = (columns per star * number of stars) + non-star columns + // So: columns per star = (total - non-star) / stars + columnsPerStar := 0 + if starCount > 0 { + columnsPerStar = (len(columns) - nonStarCount) / starCount + } + + newItems := make([]ast.Node, 0, len(columns)) + colIndex := 0 + + for _, target := range targets.Items { + resTarget, ok := target.(*ast.ResTarget) + if !ok { + newItems = append(newItems, target) + colIndex++ + continue + } + + if resTarget.Val == nil { + newItems = append(newItems, target) + colIndex++ + continue + } + + colRef, ok := resTarget.Val.(*ast.ColumnRef) + if !ok { + newItems = append(newItems, target) + colIndex++ + continue + } + + if colRef.Fields == nil { + newItems = append(newItems, target) + colIndex++ + continue + } + + // Check if this is a * (with or without table qualifier) + // and extract any table prefix + isStar := false + var tablePrefix []string + for _, field := range colRef.Fields.Items { + if _, ok := field.(*ast.A_Star); ok { + isStar = true + break + } + // Collect prefix parts (schema, table name) + if str, ok := field.(*ast.String); ok { + tablePrefix = append(tablePrefix, str.Str) + } + } + + if !isStar { + newItems = append(newItems, target) + colIndex++ + continue + } + + // Replace * with explicit column references + for i := 0; i < columnsPerStar && colIndex < len(columns); i++ { + newItems = append(newItems, makeColumnTargetWithPrefix(columns[colIndex], tablePrefix)) + colIndex++ + } + } + + return &ast.List{Items: newItems} +} + +// makeColumnTargetWithPrefix creates a ResTarget node for a column reference with optional table prefix +func makeColumnTargetWithPrefix(colName string, prefix []string) ast.Node { + fields := make([]ast.Node, 0, len(prefix)+1) + + // Add prefix parts (schema, table name) + for _, p := range prefix { + fields = append(fields, &ast.String{Str: p}) + } + + // Add column name + fields = append(fields, &ast.String{Str: colName}) + + return &ast.ResTarget{ + Val: &ast.ColumnRef{ + Fields: &ast.List{Items: fields}, + }, + } +} diff --git a/internal/x/expander/expander_test.go b/internal/x/expander/expander_test.go new file mode 100644 index 0000000000..84de74cdf3 --- /dev/null +++ b/internal/x/expander/expander_test.go @@ -0,0 +1,446 @@ +package expander + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "os" + "testing" + + "github.com/go-sql-driver/mysql" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" + + "github.com/sqlc-dev/sqlc/internal/engine/dolphin" + "github.com/sqlc-dev/sqlc/internal/engine/postgresql" + "github.com/sqlc-dev/sqlc/internal/engine/sqlite" +) + +// PostgreSQLColumnGetter implements ColumnGetter for PostgreSQL using pgxpool. +type PostgreSQLColumnGetter struct { + pool *pgxpool.Pool +} + +func (g *PostgreSQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { + conn, err := g.pool.Acquire(ctx) + if err != nil { + return nil, err + } + defer conn.Release() + + desc, err := conn.Conn().Prepare(ctx, "", query) + if err != nil { + return nil, err + } + + columns := make([]string, len(desc.Fields)) + for i, field := range desc.Fields { + columns[i] = field.Name + } + + return columns, nil +} + +// MySQLColumnGetter implements ColumnGetter for MySQL using the forked driver's StmtMetadata. +type MySQLColumnGetter struct { + db *sql.DB +} + +func (g *MySQLColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { + conn, err := g.db.Conn(ctx) + if err != nil { + return nil, err + } + defer conn.Close() + + var columns []string + err = conn.Raw(func(driverConn any) error { + preparer, ok := driverConn.(driver.ConnPrepareContext) + if !ok { + return fmt.Errorf("driver connection does not support PrepareContext") + } + + stmt, err := preparer.PrepareContext(ctx, query) + if err != nil { + return err + } + defer stmt.Close() + + meta, ok := stmt.(mysql.StmtMetadata) + if !ok { + return fmt.Errorf("prepared statement does not implement StmtMetadata") + } + + for _, col := range meta.ColumnMetadata() { + columns = append(columns, col.Name) + } + return nil + }) + if err != nil { + return nil, err + } + + return columns, nil +} + +// SQLiteColumnGetter implements ColumnGetter for SQLite using the native ncruces/go-sqlite3 API. +type SQLiteColumnGetter struct { + conn *sqlite3.Conn +} + +func (g *SQLiteColumnGetter) GetColumnNames(ctx context.Context, query string) ([]string, error) { + // Prepare the statement - this gives us column metadata without executing + stmt, _, err := g.conn.Prepare(query) + if err != nil { + return nil, err + } + defer stmt.Close() + + // Get column names from the prepared statement + count := stmt.ColumnCount() + columns := make([]string, count) + for i := 0; i < count; i++ { + columns[i] = stmt.ColumnName(i) + } + + return columns, nil +} + +func TestExpandPostgreSQL(t *testing.T) { + // Skip if no database connection available + uri := os.Getenv("POSTGRESQL_SERVER_URI") + if uri == "" { + uri = "postgres://postgres:mysecretpassword@localhost:5432/postgres" + } + + ctx := context.Background() + + pool, err := pgxpool.New(ctx, uri) + if err != nil { + t.Skipf("could not connect to database: %v", err) + } + defer pool.Close() + + // Create a test table + _, err = pool.Exec(ctx, ` + DROP TABLE IF EXISTS authors; + CREATE TABLE authors ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + bio TEXT + ); + `) + if err != nil { + t.Fatalf("failed to create test table: %v", err) + } + defer pool.Exec(ctx, "DROP TABLE IF EXISTS authors") + + // Create the parser which also implements format.Dialect + parser := postgresql.NewParser() + + // Create the expander + colGetter := &PostgreSQLColumnGetter{pool: pool} + exp := New(colGetter, parser, parser) + + tests := []struct { + name string + query string + expected string + }{ + { + name: "simple select star", + query: "SELECT * FROM authors", + expected: "SELECT id, name, bio FROM authors;", + }, + { + name: "select with no star", + query: "SELECT id, name FROM authors", + expected: "SELECT id, name FROM authors", // No change, returns original + }, + { + name: "select star with where clause", + query: "SELECT * FROM authors WHERE id = 1", + expected: "SELECT id, name, bio FROM authors WHERE id = 1;", + }, + { + name: "double star", + query: "SELECT *, * FROM authors", + expected: "SELECT id, name, bio, id, name, bio FROM authors;", + }, + { + name: "table qualified star", + query: "SELECT authors.* FROM authors", + expected: "SELECT authors.id, authors.name, authors.bio FROM authors;", + }, + { + name: "star in middle of columns", + query: "SELECT id, *, name FROM authors", + expected: "SELECT id, id, name, bio, name FROM authors;", + }, + { + name: "insert returning star", + query: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING *", + expected: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, name, bio;", + }, + { + name: "insert returning mixed", + query: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, *", + expected: "INSERT INTO authors (name, bio) VALUES ('John', 'A writer') RETURNING id, id, name, bio;", + }, + { + name: "update returning star", + query: "UPDATE authors SET name = 'Jane' WHERE id = 1 RETURNING *", + expected: "UPDATE authors SET name = 'Jane' WHERE id = 1 RETURNING id, name, bio;", + }, + { + name: "delete returning star", + query: "DELETE FROM authors WHERE id = 1 RETURNING *", + expected: "DELETE FROM authors WHERE id = 1 RETURNING id, name, bio;", + }, + { + name: "cte with select star", + query: "WITH a AS (SELECT * FROM authors) SELECT * FROM a", + expected: "WITH a AS (SELECT id, name, bio FROM authors) SELECT id, name, bio FROM a;", + }, + { + name: "multiple ctes with dependency", + query: "WITH a AS (SELECT * FROM authors), b AS (SELECT * FROM a) SELECT * FROM b", + expected: "WITH a AS (SELECT id, name, bio FROM authors), b AS (SELECT id, name, bio FROM a) SELECT id, name, bio FROM b;", + }, + { + name: "count star not expanded", + query: "SELECT COUNT(*) FROM authors", + expected: "SELECT COUNT(*) FROM authors", // No change - COUNT(*) should not be expanded + }, + { + name: "count star with other columns", + query: "SELECT COUNT(*), name FROM authors GROUP BY name", + expected: "SELECT COUNT(*), name FROM authors GROUP BY name", // No change + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := exp.Expand(ctx, tc.query) + if err != nil { + t.Fatalf("Expand failed: %v", err) + } + if result != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, result) + } + }) + } +} + +func TestExpandMySQL(t *testing.T) { + // Get MySQL connection parameters + user := os.Getenv("MYSQL_USER") + if user == "" { + user = "root" + } + pass := os.Getenv("MYSQL_ROOT_PASSWORD") + if pass == "" { + pass = "mysecretpassword" + } + host := os.Getenv("MYSQL_HOST") + if host == "" { + host = "127.0.0.1" + } + port := os.Getenv("MYSQL_PORT") + if port == "" { + port = "3306" + } + dbname := os.Getenv("MYSQL_DATABASE") + if dbname == "" { + dbname = "dinotest" + } + + source := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true&parseTime=true", user, pass, host, port, dbname) + + ctx := context.Background() + + db, err := sql.Open("mysql", source) + if err != nil { + t.Skipf("could not connect to MySQL: %v", err) + } + defer db.Close() + + // Verify connection + if err := db.Ping(); err != nil { + t.Skipf("could not ping MySQL: %v", err) + } + + // Create a test table + _, err = db.ExecContext(ctx, `DROP TABLE IF EXISTS authors`) + if err != nil { + t.Fatalf("failed to drop test table: %v", err) + } + _, err = db.ExecContext(ctx, ` + CREATE TABLE authors ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) NOT NULL, + bio TEXT + ) + `) + if err != nil { + t.Fatalf("failed to create test table: %v", err) + } + defer db.ExecContext(ctx, "DROP TABLE IF EXISTS authors") + + // Create the parser which also implements format.Dialect + parser := dolphin.NewParser() + + // Create the expander + colGetter := &MySQLColumnGetter{db: db} + exp := New(colGetter, parser, parser) + + tests := []struct { + name string + query string + expected string + }{ + { + name: "simple select star", + query: "SELECT * FROM authors", + expected: "SELECT id, name, bio FROM authors;", + }, + { + name: "select with no star", + query: "SELECT id, name FROM authors", + expected: "SELECT id, name FROM authors", // No change, returns original + }, + { + name: "select star with where clause", + query: "SELECT * FROM authors WHERE id = 1", + expected: "SELECT id, name, bio FROM authors WHERE id = 1;", + }, + { + name: "table qualified star", + query: "SELECT authors.* FROM authors", + expected: "SELECT authors.id, authors.name, authors.bio FROM authors;", + }, + { + name: "double table qualified star", + query: "SELECT authors.*, authors.* FROM authors", + expected: "SELECT authors.id, authors.name, authors.bio, authors.id, authors.name, authors.bio FROM authors;", + }, + { + name: "star in middle of columns table qualified", + query: "SELECT id, authors.*, name FROM authors", + expected: "SELECT id, authors.id, authors.name, authors.bio, name FROM authors;", + }, + { + name: "count star not expanded", + query: "SELECT COUNT(*) FROM authors", + expected: "SELECT COUNT(*) FROM authors", // No change - COUNT(*) should not be expanded + }, + { + name: "count star with other columns", + query: "SELECT COUNT(*), name FROM authors GROUP BY name", + expected: "SELECT COUNT(*), name FROM authors GROUP BY name", // No change + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := exp.Expand(ctx, tc.query) + if err != nil { + t.Fatalf("Expand failed: %v", err) + } + if result != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, result) + } + }) + } +} + +func TestExpandSQLite(t *testing.T) { + ctx := context.Background() + + // Create an in-memory SQLite database using native API + conn, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatalf("could not open SQLite: %v", err) + } + defer conn.Close() + + // Create a test table + err = conn.Exec(` + CREATE TABLE authors ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + bio TEXT + ) + `) + if err != nil { + t.Fatalf("failed to create test table: %v", err) + } + + // Create the parser which also implements format.Dialect + parser := sqlite.NewParser() + + // Create the expander using native SQLite column getter + colGetter := &SQLiteColumnGetter{conn: conn} + exp := New(colGetter, parser, parser) + + tests := []struct { + name string + query string + expected string + }{ + { + name: "simple select star", + query: "SELECT * FROM authors", + expected: "SELECT id, name, bio FROM authors;", + }, + { + name: "select with no star", + query: "SELECT id, name FROM authors", + expected: "SELECT id, name FROM authors", // No change, returns original + }, + { + name: "select star with where clause", + query: "SELECT * FROM authors WHERE id = 1", + expected: "SELECT id, name, bio FROM authors WHERE id = 1;", + }, + { + name: "double star", + query: "SELECT *, * FROM authors", + expected: "SELECT id, name, bio, id, name, bio FROM authors;", + }, + { + name: "table qualified star", + query: "SELECT authors.* FROM authors", + expected: "SELECT authors.id, authors.name, authors.bio FROM authors;", + }, + { + name: "star in middle of columns", + query: "SELECT id, *, name FROM authors", + expected: "SELECT id, id, name, bio, name FROM authors;", + }, + { + name: "count star not expanded", + query: "SELECT COUNT(*) FROM authors", + expected: "SELECT COUNT(*) FROM authors", // No change - COUNT(*) should not be expanded + }, + { + name: "count star with other columns", + query: "SELECT COUNT(*), name FROM authors GROUP BY name", + expected: "SELECT COUNT(*), name FROM authors GROUP BY name", // No change + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result, err := exp.Expand(ctx, tc.query) + if err != nil { + t.Fatalf("Expand failed: %v", err) + } + if result != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, result) + } + }) + } +}