Skip to content

Commit fea4b4c

Browse files
kyleconroyclaude
andcommitted
feat(mysql): Use forked driver to get prepared statement metadata
Update the MySQL analyzer to use sqlc-dev/mysql fork which exposes column and parameter metadata from COM_STMT_PREPARE responses. Changes: - Add replace directive for github.com/go-sql-driver/mysql to use github.com/sqlc-dev/mysql@expose_query_metadata - Update analyzer to use driver.ConnPrepareContext and type-assert to mysql.StmtMetadata to access ColumnMetadata() and ParamMetadata() - Remove the old approach of executing wrapped queries with dummy params - Remove unused countParameters and isSelectQuery functions The forked driver reads and stores metadata that the upstream driver discards, allowing sqlc to get accurate type information for both query result columns and parameters directly from MySQL's PREPARE response without executing the query. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent e52cd4e commit fea4b4c

File tree

3 files changed

+86
-112
lines changed

3 files changed

+86
-112
lines changed

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,5 @@ require (
6464
google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect
6565
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
6666
)
67+
68+
replace github.com/go-sql-driver/mysql => github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
2626
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
2727
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
2828
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
29-
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
30-
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
3129
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
3230
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
3331
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
@@ -159,6 +157,8 @@ github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4
159157
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
160158
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
161159
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
160+
github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2 h1:kmCAKKtOgK6EXXQX9oPdEASIhgor7TCpWxD8NtcqVcU=
161+
github.com/sqlc-dev/mysql v0.0.0-20251129233104-d81e1cac6db2/go.mod h1:TrDMWzjNTKvJeK2GC8uspG+PWyPLiY9QKvwdWpAdlZE=
162162
github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU=
163163
github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
164164
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

internal/engine/dolphin/analyzer/analyze.go

Lines changed: 82 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ package analyzer
33
import (
44
"context"
55
"database/sql"
6+
"database/sql/driver"
67
"fmt"
78
"hash/fnv"
89
"io"
910
"strings"
1011
"sync"
1112

12-
_ "github.com/go-sql-driver/mysql"
13+
"github.com/go-sql-driver/mysql"
1314

1415
core "github.com/sqlc-dev/sqlc/internal/analysis"
1516
"github.com/sqlc-dev/sqlc/internal/config"
@@ -139,90 +140,102 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat
139140
}
140141
}
141142

142-
// Count parameters in the query
143-
paramCount := countParameters(query)
144-
145-
// Try to prepare the statement first to validate syntax
146-
stmt, err := a.conn.PrepareContext(ctx, query)
143+
// Get metadata directly from prepared statement via driver connection
144+
result, err := a.getStatementMetadata(ctx, n, query, ps)
147145
if err != nil {
148-
return nil, a.extractSqlErr(n, err)
146+
return nil, err
149147
}
150-
stmt.Close()
151148

149+
return result, nil
150+
}
151+
152+
// getStatementMetadata uses the MySQL driver's prepared statement metadata API
153+
// to get column and parameter type information without executing the query
154+
func (a *Analyzer) getStatementMetadata(ctx context.Context, n ast.Node, query string, ps *named.ParamSet) (*core.Analysis, error) {
152155
var result core.Analysis
153156

154-
// For SELECT queries, execute with default parameter values to get column metadata
155-
if isSelectQuery(query) {
156-
cols, err := a.getColumnMetadata(ctx, query, paramCount)
157-
if err == nil {
158-
result.Columns = cols
159-
}
160-
// If we fail to get column metadata, fall through to return empty columns
161-
// and let the catalog-based inference handle it
157+
// Get a raw connection to access driver-level prepared statement
158+
conn, err := a.conn.Conn(ctx)
159+
if err != nil {
160+
return nil, a.extractSqlErr(n, fmt.Errorf("failed to get connection: %w", err))
162161
}
162+
defer conn.Close()
163163

164-
// Build parameter info
165-
for i := 1; i <= paramCount; i++ {
166-
name := ""
167-
if ps != nil {
168-
name, _ = ps.NameFor(i)
164+
err = conn.Raw(func(driverConn any) error {
165+
// Get the driver connection that supports PrepareContext
166+
preparer, ok := driverConn.(driver.ConnPrepareContext)
167+
if !ok {
168+
return fmt.Errorf("driver connection does not support PrepareContext")
169169
}
170-
result.Params = append(result.Params, &core.Parameter{
171-
Number: int32(i),
172-
Column: &core.Column{
173-
Name: name,
174-
DataType: "any",
175-
NotNull: false,
176-
},
177-
})
178-
}
179-
180-
return &result, nil
181-
}
182170

183-
// isSelectQuery checks if a query is a SELECT statement
184-
func isSelectQuery(query string) bool {
185-
trimmed := strings.TrimSpace(strings.ToUpper(query))
186-
return strings.HasPrefix(trimmed, "SELECT") ||
187-
strings.HasPrefix(trimmed, "WITH") // CTEs
188-
}
171+
// Prepare the statement - this sends COM_STMT_PREPARE to MySQL
172+
// and receives column and parameter metadata
173+
stmt, err := preparer.PrepareContext(ctx, query)
174+
if err != nil {
175+
return err
176+
}
177+
defer stmt.Close()
178+
179+
// Access the metadata via the StmtMetadata interface from our forked driver
180+
meta, ok := stmt.(mysql.StmtMetadata)
181+
if !ok {
182+
// Fallback: just use param count from NumInput
183+
paramCount := stmt.NumInput()
184+
for i := 1; i <= paramCount; i++ {
185+
name := ""
186+
if ps != nil {
187+
name, _ = ps.NameFor(i)
188+
}
189+
result.Params = append(result.Params, &core.Parameter{
190+
Number: int32(i),
191+
Column: &core.Column{
192+
Name: name,
193+
DataType: "any",
194+
NotNull: false,
195+
},
196+
})
197+
}
198+
return nil
199+
}
189200

190-
// getColumnMetadata executes the query with default values to retrieve column information
191-
func (a *Analyzer) getColumnMetadata(ctx context.Context, query string, paramCount int) ([]*core.Column, error) {
192-
// Generate default parameter values (use 1 for all - works for most types)
193-
args := make([]any, paramCount)
194-
for i := range args {
195-
args[i] = 1
196-
}
201+
// Get column metadata
202+
for _, col := range meta.ColumnMetadata() {
203+
result.Columns = append(result.Columns, &core.Column{
204+
Name: col.Name,
205+
DataType: strings.ToLower(col.DatabaseTypeName),
206+
NotNull: !col.Nullable,
207+
Unsigned: col.Unsigned,
208+
Length: int32(col.Length),
209+
})
210+
}
197211

198-
// Wrap query to avoid fetching data: SELECT * FROM (query) AS _sqlc_wrapper LIMIT 0
199-
// This ensures we get column metadata without executing the actual query
200-
wrappedQuery := fmt.Sprintf("SELECT * FROM (%s) AS _sqlc_wrapper LIMIT 0", query)
212+
// Get parameter metadata
213+
paramMeta := meta.ParamMetadata()
214+
for i, param := range paramMeta {
215+
name := ""
216+
if ps != nil {
217+
name, _ = ps.NameFor(i + 1)
218+
}
219+
result.Params = append(result.Params, &core.Parameter{
220+
Number: int32(i + 1),
221+
Column: &core.Column{
222+
Name: name,
223+
DataType: strings.ToLower(param.DatabaseTypeName),
224+
NotNull: !param.Nullable,
225+
Unsigned: param.Unsigned,
226+
Length: int32(param.Length),
227+
},
228+
})
229+
}
201230

202-
rows, err := a.conn.QueryContext(ctx, wrappedQuery, args...)
203-
if err != nil {
204-
// If wrapped query fails, try direct query with LIMIT 0
205-
// Some queries may not support being wrapped (e.g., queries with UNION at the end)
206-
return nil, err
207-
}
208-
defer rows.Close()
231+
return nil
232+
})
209233

210-
colTypes, err := rows.ColumnTypes()
211234
if err != nil {
212-
return nil, err
213-
}
214-
215-
var columns []*core.Column
216-
for _, col := range colTypes {
217-
nullable, _ := col.Nullable()
218-
columns = append(columns, &core.Column{
219-
Name: col.Name(),
220-
DataType: strings.ToLower(col.DatabaseTypeName()),
221-
NotNull: !nullable,
222-
})
235+
return nil, a.extractSqlErr(n, err)
223236
}
224237

225-
return columns, nil
238+
return &result, nil
226239
}
227240

228241
// replaceDatabase replaces the database name in a MySQL DSN
@@ -253,47 +266,6 @@ func replaceDatabase(dsn string, newDB string) string {
253266
return dsn[:slashIdx+1] + newDB + dsn[slashIdx+paramIdx:]
254267
}
255268

256-
// countParameters counts the number of ? placeholders in a query
257-
func countParameters(query string) int {
258-
count := 0
259-
inString := false
260-
stringChar := byte(0)
261-
escaped := false
262-
263-
for i := 0; i < len(query); i++ {
264-
c := query[i]
265-
266-
if escaped {
267-
escaped = false
268-
continue
269-
}
270-
271-
if c == '\\' {
272-
escaped = true
273-
continue
274-
}
275-
276-
if inString {
277-
if c == stringChar {
278-
inString = false
279-
}
280-
continue
281-
}
282-
283-
if c == '\'' || c == '"' || c == '`' {
284-
inString = true
285-
stringChar = c
286-
continue
287-
}
288-
289-
if c == '?' {
290-
count++
291-
}
292-
}
293-
294-
return count
295-
}
296-
297269
func (a *Analyzer) extractSqlErr(n ast.Node, err error) error {
298270
if err == nil {
299271
return nil

0 commit comments

Comments
 (0)