@@ -3,6 +3,7 @@ package analyzer
33import (
44 "context"
55 "database/sql"
6+ "database/sql/driver"
67 "fmt"
78 "hash/fnv"
89 "io"
@@ -139,90 +140,61 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat
139140 }
140141 }
141142
142- // Count parameters in the query
143- paramCount := countParameters (query )
144-
145- // Try to prepare the statement first to validate syntax
146- stmt , err := a .conn .PrepareContext (ctx , query )
143+ // Get metadata directly from prepared statement via driver connection
144+ result , err := a .getStatementMetadata (ctx , n , query , ps )
147145 if err != nil {
148- return nil , a . extractSqlErr ( n , err )
146+ return nil , err
149147 }
150- stmt .Close ()
151148
149+ return result , nil
150+ }
151+
152+ // getStatementMetadata validates the query by preparing it against the database.
153+ // It returns empty columns/params to let catalog-based inference handle types,
154+ // since MySQL's metadata types don't always match what sqlc's type system expects
155+ // (e.g., MySQL returns BIGINT for boolean expressions, generic types for parameters).
156+ func (a * Analyzer ) getStatementMetadata (ctx context.Context , n ast.Node , query string , ps * named.ParamSet ) (* core.Analysis , error ) {
152157 var result core.Analysis
153158
154- // For SELECT queries, execute with default parameter values to get column metadata
155- if isSelectQuery (query ) {
156- cols , err := a .getColumnMetadata (ctx , query , paramCount )
157- if err == nil {
158- result .Columns = cols
159- }
160- // If we fail to get column metadata, fall through to return empty columns
161- // and let the catalog-based inference handle it
159+ // Get a raw connection to access driver-level prepared statement
160+ conn , err := a .conn .Conn (ctx )
161+ if err != nil {
162+ return nil , a .extractSqlErr (n , fmt .Errorf ("failed to get connection: %w" , err ))
162163 }
164+ defer conn .Close ()
163165
164- // Build parameter info
165- for i := 1 ; i <= paramCount ; i ++ {
166- name := ""
167- if ps != nil {
168- name , _ = ps . NameFor ( i )
166+ err = conn . Raw ( func ( driverConn any ) error {
167+ // Get the driver connection that supports PrepareContext
168+ preparer , ok := driverConn .(driver. ConnPrepareContext )
169+ if ! ok {
170+ return fmt . Errorf ( "driver connection does not support PrepareContext" )
169171 }
170- result .Params = append (result .Params , & core.Parameter {
171- Number : int32 (i ),
172- Column : & core.Column {
173- Name : name ,
174- DataType : "any" ,
175- NotNull : false ,
176- },
177- })
178- }
179-
180- return & result , nil
181- }
182172
183- // isSelectQuery checks if a query is a SELECT statement
184- func isSelectQuery (query string ) bool {
185- trimmed := strings .TrimSpace (strings .ToUpper (query ))
186- return strings .HasPrefix (trimmed , "SELECT" ) ||
187- strings .HasPrefix (trimmed , "WITH" ) // CTEs
188- }
189-
190- // getColumnMetadata executes the query with default values to retrieve column information
191- func (a * Analyzer ) getColumnMetadata (ctx context.Context , query string , paramCount int ) ([]* core.Column , error ) {
192- // Generate default parameter values (use 1 for all - works for most types)
193- args := make ([]any , paramCount )
194- for i := range args {
195- args [i ] = 1
196- }
173+ // Prepare the statement - this validates the SQL syntax and schema references
174+ // by sending COM_STMT_PREPARE to MySQL
175+ stmt , err := preparer .PrepareContext (ctx , query )
176+ if err != nil {
177+ return err
178+ }
179+ defer stmt .Close ()
197180
198- // Wrap query to avoid fetching data: SELECT * FROM (query) AS _sqlc_wrapper LIMIT 0
199- // This ensures we get column metadata without executing the actual query
200- wrappedQuery := fmt .Sprintf ("SELECT * FROM (%s) AS _sqlc_wrapper LIMIT 0" , query )
181+ // We intentionally don't use the column/parameter metadata from MySQL
182+ // because MySQL's type system doesn't always align with sqlc's expectations:
183+ // - Boolean expressions return BIGINT instead of bool
184+ // - Parameters get generic types (BIGINT/VARCHAR) instead of column types
185+ // - Type names differ from what the catalog inference provides
186+ //
187+ // By returning empty results, combineAnalysis() in analyze.go will
188+ // preserve the catalog-inferred types which match the expected output.
201189
202- rows , err := a .conn .QueryContext (ctx , wrappedQuery , args ... )
203- if err != nil {
204- // If wrapped query fails, try direct query with LIMIT 0
205- // Some queries may not support being wrapped (e.g., queries with UNION at the end)
206- return nil , err
207- }
208- defer rows .Close ()
190+ return nil
191+ })
209192
210- colTypes , err := rows .ColumnTypes ()
211193 if err != nil {
212- return nil , err
213- }
214-
215- var columns []* core.Column
216- for _ , col := range colTypes {
217- nullable , _ := col .Nullable ()
218- columns = append (columns , & core.Column {
219- Name : col .Name (),
220- DataType : strings .ToLower (col .DatabaseTypeName ()),
221- NotNull : ! nullable ,
222- })
194+ return nil , a .extractSqlErr (n , err )
223195 }
224196
225- return columns , nil
197+ return & result , nil
226198}
227199
228200// replaceDatabase replaces the database name in a MySQL DSN
@@ -253,47 +225,6 @@ func replaceDatabase(dsn string, newDB string) string {
253225 return dsn [:slashIdx + 1 ] + newDB + dsn [slashIdx + paramIdx :]
254226}
255227
256- // countParameters counts the number of ? placeholders in a query
257- func countParameters (query string ) int {
258- count := 0
259- inString := false
260- stringChar := byte (0 )
261- escaped := false
262-
263- for i := 0 ; i < len (query ); i ++ {
264- c := query [i ]
265-
266- if escaped {
267- escaped = false
268- continue
269- }
270-
271- if c == '\\' {
272- escaped = true
273- continue
274- }
275-
276- if inString {
277- if c == stringChar {
278- inString = false
279- }
280- continue
281- }
282-
283- if c == '\'' || c == '"' || c == '`' {
284- inString = true
285- stringChar = c
286- continue
287- }
288-
289- if c == '?' {
290- count ++
291- }
292- }
293-
294- return count
295- }
296-
297228func (a * Analyzer ) extractSqlErr (n ast.Node , err error ) error {
298229 if err == nil {
299230 return nil
0 commit comments