@@ -3,13 +3,14 @@ package analyzer
33import (
44 "context"
55 "database/sql"
6+ "database/sql/driver"
67 "fmt"
78 "hash/fnv"
89 "io"
910 "strings"
1011 "sync"
1112
12- _ "github.com/go-sql-driver/mysql"
13+ "github.com/go-sql-driver/mysql"
1314
1415 core "github.com/sqlc-dev/sqlc/internal/analysis"
1516 "github.com/sqlc-dev/sqlc/internal/config"
@@ -139,90 +140,102 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat
139140 }
140141 }
141142
142- // Count parameters in the query
143- paramCount := countParameters (query )
144-
145- // Try to prepare the statement first to validate syntax
146- stmt , err := a .conn .PrepareContext (ctx , query )
143+ // Get metadata directly from prepared statement via driver connection
144+ result , err := a .getStatementMetadata (ctx , n , query , ps )
147145 if err != nil {
148- return nil , a . extractSqlErr ( n , err )
146+ return nil , err
149147 }
150- stmt .Close ()
151148
149+ return result , nil
150+ }
151+
152+ // getStatementMetadata uses the MySQL driver's prepared statement metadata API
153+ // to get column and parameter type information without executing the query
154+ func (a * Analyzer ) getStatementMetadata (ctx context.Context , n ast.Node , query string , ps * named.ParamSet ) (* core.Analysis , error ) {
152155 var result core.Analysis
153156
154- // For SELECT queries, execute with default parameter values to get column metadata
155- if isSelectQuery (query ) {
156- cols , err := a .getColumnMetadata (ctx , query , paramCount )
157- if err == nil {
158- result .Columns = cols
159- }
160- // If we fail to get column metadata, fall through to return empty columns
161- // and let the catalog-based inference handle it
157+ // Get a raw connection to access driver-level prepared statement
158+ conn , err := a .conn .Conn (ctx )
159+ if err != nil {
160+ return nil , a .extractSqlErr (n , fmt .Errorf ("failed to get connection: %w" , err ))
162161 }
162+ defer conn .Close ()
163163
164- // Build parameter info
165- for i := 1 ; i <= paramCount ; i ++ {
166- name := ""
167- if ps != nil {
168- name , _ = ps . NameFor ( i )
164+ err = conn . Raw ( func ( driverConn any ) error {
165+ // Get the driver connection that supports PrepareContext
166+ preparer , ok := driverConn .(driver. ConnPrepareContext )
167+ if ! ok {
168+ return fmt . Errorf ( "driver connection does not support PrepareContext" )
169169 }
170- result .Params = append (result .Params , & core.Parameter {
171- Number : int32 (i ),
172- Column : & core.Column {
173- Name : name ,
174- DataType : "any" ,
175- NotNull : false ,
176- },
177- })
178- }
179-
180- return & result , nil
181- }
182170
183- // isSelectQuery checks if a query is a SELECT statement
184- func isSelectQuery (query string ) bool {
185- trimmed := strings .TrimSpace (strings .ToUpper (query ))
186- return strings .HasPrefix (trimmed , "SELECT" ) ||
187- strings .HasPrefix (trimmed , "WITH" ) // CTEs
188- }
171+ // Prepare the statement - this sends COM_STMT_PREPARE to MySQL
172+ // and receives column and parameter metadata
173+ stmt , err := preparer .PrepareContext (ctx , query )
174+ if err != nil {
175+ return err
176+ }
177+ defer stmt .Close ()
178+
179+ // Access the metadata via the StmtMetadata interface from our forked driver
180+ meta , ok := stmt .(mysql.StmtMetadata )
181+ if ! ok {
182+ // Fallback: just use param count from NumInput
183+ paramCount := stmt .NumInput ()
184+ for i := 1 ; i <= paramCount ; i ++ {
185+ name := ""
186+ if ps != nil {
187+ name , _ = ps .NameFor (i )
188+ }
189+ result .Params = append (result .Params , & core.Parameter {
190+ Number : int32 (i ),
191+ Column : & core.Column {
192+ Name : name ,
193+ DataType : "any" ,
194+ NotNull : false ,
195+ },
196+ })
197+ }
198+ return nil
199+ }
189200
190- // getColumnMetadata executes the query with default values to retrieve column information
191- func (a * Analyzer ) getColumnMetadata (ctx context.Context , query string , paramCount int ) ([]* core.Column , error ) {
192- // Generate default parameter values (use 1 for all - works for most types)
193- args := make ([]any , paramCount )
194- for i := range args {
195- args [i ] = 1
196- }
201+ // Get column metadata
202+ for _ , col := range meta .ColumnMetadata () {
203+ result .Columns = append (result .Columns , & core.Column {
204+ Name : col .Name ,
205+ DataType : strings .ToLower (col .DatabaseTypeName ),
206+ NotNull : ! col .Nullable ,
207+ Unsigned : col .Unsigned ,
208+ Length : int32 (col .Length ),
209+ })
210+ }
197211
198- // Wrap query to avoid fetching data: SELECT * FROM (query) AS _sqlc_wrapper LIMIT 0
199- // This ensures we get column metadata without executing the actual query
200- wrappedQuery := fmt .Sprintf ("SELECT * FROM (%s) AS _sqlc_wrapper LIMIT 0" , query )
212+ // Get parameter metadata
213+ paramMeta := meta .ParamMetadata ()
214+ for i , param := range paramMeta {
215+ name := ""
216+ if ps != nil {
217+ name , _ = ps .NameFor (i + 1 )
218+ }
219+ result .Params = append (result .Params , & core.Parameter {
220+ Number : int32 (i + 1 ),
221+ Column : & core.Column {
222+ Name : name ,
223+ DataType : strings .ToLower (param .DatabaseTypeName ),
224+ NotNull : ! param .Nullable ,
225+ Unsigned : param .Unsigned ,
226+ Length : int32 (param .Length ),
227+ },
228+ })
229+ }
201230
202- rows , err := a .conn .QueryContext (ctx , wrappedQuery , args ... )
203- if err != nil {
204- // If wrapped query fails, try direct query with LIMIT 0
205- // Some queries may not support being wrapped (e.g., queries with UNION at the end)
206- return nil , err
207- }
208- defer rows .Close ()
231+ return nil
232+ })
209233
210- colTypes , err := rows .ColumnTypes ()
211234 if err != nil {
212- return nil , err
213- }
214-
215- var columns []* core.Column
216- for _ , col := range colTypes {
217- nullable , _ := col .Nullable ()
218- columns = append (columns , & core.Column {
219- Name : col .Name (),
220- DataType : strings .ToLower (col .DatabaseTypeName ()),
221- NotNull : ! nullable ,
222- })
235+ return nil , a .extractSqlErr (n , err )
223236 }
224237
225- return columns , nil
238+ return & result , nil
226239}
227240
228241// replaceDatabase replaces the database name in a MySQL DSN
@@ -253,47 +266,6 @@ func replaceDatabase(dsn string, newDB string) string {
253266 return dsn [:slashIdx + 1 ] + newDB + dsn [slashIdx + paramIdx :]
254267}
255268
256- // countParameters counts the number of ? placeholders in a query
257- func countParameters (query string ) int {
258- count := 0
259- inString := false
260- stringChar := byte (0 )
261- escaped := false
262-
263- for i := 0 ; i < len (query ); i ++ {
264- c := query [i ]
265-
266- if escaped {
267- escaped = false
268- continue
269- }
270-
271- if c == '\\' {
272- escaped = true
273- continue
274- }
275-
276- if inString {
277- if c == stringChar {
278- inString = false
279- }
280- continue
281- }
282-
283- if c == '\'' || c == '"' || c == '`' {
284- inString = true
285- stringChar = c
286- continue
287- }
288-
289- if c == '?' {
290- count ++
291- }
292- }
293-
294- return count
295- }
296-
297269func (a * Analyzer ) extractSqlErr (n ast.Node , err error ) error {
298270 if err == nil {
299271 return nil
0 commit comments