Skip to content

Commit

Permalink
feat: make the arguments print themselves with type info (#16232)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
Signed-off-by: Florent Poinsard <florent.poinsard@outlook.fr>
Signed-off-by: Manan Gupta <manan@planetscale.com>
Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
Co-authored-by: Florent Poinsard <florent.poinsard@outlook.fr>
Co-authored-by: Manan Gupta <manan@planetscale.com>
Co-authored-by: Dirkjan Bussink <d.bussink@gmail.com>
  • Loading branch information
4 people authored Jun 27, 2024
1 parent 96f9c3d commit 4922a3a
Show file tree
Hide file tree
Showing 30 changed files with 473 additions and 181 deletions.
7 changes: 1 addition & 6 deletions go/test/endtoend/vtgate/queries/normalize/normalize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"

"vitess.io/vitess/go/mysql"
Expand All @@ -41,11 +40,7 @@ func TestNormalizeAllFields(t *testing.T) {

insertQuery := `insert into t1 values (1, "chars", "variable chars", x'73757265', 0x676F, 0.33, 9.99, 1, "1976-06-08", "small", "b", "{\"key\":\"value\"}", point(1,5), b'011', 0b0101)`
normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL(3,2) */, :vtg7 /* DECIMAL(3,2) */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)`
vtgateVersion, err := cluster.GetMajorVersion("vtgate")
require.NoError(t, err)
if vtgateVersion < 20 {
normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)`
}

selectQuery := "select * from t1"
utils.Exec(t, conn, insertQuery)
qr := utils.Exec(t, conn, selectQuery)
Expand Down
19 changes: 18 additions & 1 deletion go/test/endtoend/vtgate/queries/subquery/subquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"

"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"
)
Expand All @@ -34,7 +36,7 @@ func start(t *testing.T) (utils.MySQLCompare, func()) {
deleteAll := func() {
_, _ = utils.ExecAllowError(t, mcmp.VtConn, "set workload = oltp")

tables := []string{"t1", "t1_id2_idx", "t2", "t2_id4_idx"}
tables := []string{"t1", "t1_id2_idx", "t2", "t2_id4_idx", "user", "user_extra"}
for _, table := range tables {
_, _ = mcmp.ExecAndIgnore("delete from " + table)
}
Expand Down Expand Up @@ -232,3 +234,18 @@ func TestSubqueries(t *testing.T) {
})
}
}

func TestProperTypesOfPullOutValue(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 21, "vtgate")

query := "select (select sum(id) from user) from user_extra"

mcmp, closer := start(t)
defer closer()

mcmp.Exec("INSERT INTO user (id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'), (4, 'David'), (5, 'Eve'), (6, 'Frank'), (7, 'Grace'), (8, 'Hannah'), (9, 'Ivy'), (10, 'Jack')")
mcmp.Exec("INSERT INTO user_extra (user_id, extra_info) VALUES (1, 'info1'), (2, 'info1'), (3, 'info1'), (3, 'info2'), (4, 'info1'), (5, 'info1'), (6, 'info1'), (7, 'info1'), (8, 'info1')")

r := mcmp.Exec(query)
require.True(t, r.Fields[0].Type == sqltypes.Decimal)
}
58 changes: 58 additions & 0 deletions go/vt/sqlparser/ast_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,64 @@ func (node *Literal) Format(buf *TrackedBuffer) {

// Format formats the node.
func (node *Argument) Format(buf *TrackedBuffer) {
// We need to make sure that any value used still returns
// the right type when interpolated. For example, if we have a
// decimal type with 0 scale, we don't want it to be interpreted
// as an integer after interpolation as that would the default
// literal interpretation in MySQL.
switch {
case node.Type == sqltypes.Unknown:
// Ensure we handle unknown first as we don't want to treat
// the type as a bitmask for the further tests.
// do nothing, the default literal will be correct.
case sqltypes.IsDecimal(node.Type) && node.Scale == 0:
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.astPrintf(node, " AS DECIMAL(%d, %d))", node.Size, node.Scale)
return
case sqltypes.IsUnsigned(node.Type):
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.WriteString(" AS UNSIGNED)")
return
case node.Type == sqltypes.Float64:
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.WriteString(" AS DOUBLE)")
return
case node.Type == sqltypes.Float32:
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.WriteString(" AS FLOAT)")
return
case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime:
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.WriteString(" AS DATETIME")
if node.Size == 0 {
buf.WriteString(")")
return
}
buf.astPrintf(node, "(%d))", node.Size)
return
case sqltypes.IsDate(node.Type):
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.WriteString(" AS DATE")
buf.WriteString(")")
return
case node.Type == sqltypes.Time:
buf.WriteString("CAST(")
buf.WriteArg(":", node.Name)
buf.WriteString(" AS TIME")
if node.Size == 0 {
buf.WriteString(")")
return
}
buf.astPrintf(node, "(%d))", node.Size)
return
}
// Nothing special to do, the default literal will be correct.
buf.WriteArg(":", node.Name)
if node.Type >= 0 {
// For bind variables that are statically typed, emit their type as an adjacent comment.
Expand Down
66 changes: 66 additions & 0 deletions go/vt/sqlparser/ast_format_fast.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 12 additions & 5 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,24 @@ func TestNormalize(t *testing.T) {
}, {
// datetime val
in: "select * from t where foobar = timestamp'2012-02-29 12:34:56.123456'",
outstmt: "select * from t where foobar = :foobar /* DATETIME(6) */",
outstmt: "select * from t where foobar = CAST(:foobar AS DATETIME(6))",
outbv: map[string]*querypb.BindVariable{
"foobar": sqltypes.ValueBindVariable(sqltypes.NewDatetime("2012-02-29 12:34:56.123456")),
},
}, {
// time val
in: "select * from t where foobar = time'12:34:56.123456'",
outstmt: "select * from t where foobar = :foobar /* TIME(6) */",
outstmt: "select * from t where foobar = CAST(:foobar AS TIME(6))",
outbv: map[string]*querypb.BindVariable{
"foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56.123456")),
},
}, {
// time val
in: "select * from t where foobar = time'12:34:56'",
outstmt: "select * from t where foobar = CAST(:foobar AS TIME)",
outbv: map[string]*querypb.BindVariable{
"foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56")),
},
}, {
// multiple vals
in: "select * from t where foo = 1.2 and bar = 2",
Expand Down Expand Up @@ -334,21 +341,21 @@ func TestNormalize(t *testing.T) {
}, {
// DateVal should also be normalized
in: `select date'2022-08-06'`,
outstmt: `select :bv1 /* DATE */ from dual`,
outstmt: `select CAST(:bv1 AS DATE) from dual`,
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Date, []byte("2022-08-06"))),
},
}, {
// TimeVal should also be normalized
in: `select time'17:05:12'`,
outstmt: `select :bv1 /* TIME */ from dual`,
outstmt: `select CAST(:bv1 AS TIME) from dual`,
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Time, []byte("17:05:12"))),
},
}, {
// TimestampVal should also be normalized
in: `select timestamp'2022-08-06 17:05:12'`,
outstmt: `select :bv1 /* DATETIME */ from dual`,
outstmt: `select CAST(:bv1 AS DATETIME) from dual`,
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Datetime, []byte("2022-08-06 17:05:12"))),
},
Expand Down
94 changes: 92 additions & 2 deletions go/vt/sqlparser/parsed_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ import (
"reflect"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"

"github.com/stretchr/testify/assert"
)

func TestNewParsedQuery(t *testing.T) {
Expand Down Expand Up @@ -205,3 +206,92 @@ func TestParseAndBind(t *testing.T) {
})
}
}

func TestCastBindVars(t *testing.T) {
testcases := []struct {
typ sqltypes.Type
size int
binds map[string]*querypb.BindVariable
out string
}{
{
typ: sqltypes.Decimal,
binds: map[string]*querypb.BindVariable{"arg": sqltypes.DecimalBindVariable("50")},
out: "select CAST(50 AS DECIMAL(0, 0)) from ",
},
{
typ: sqltypes.Uint32,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Uint32, Value: sqltypes.NewUint32(42).Raw()}},
out: "select CAST(42 AS UNSIGNED) from ",
},
{
typ: sqltypes.Float64,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Float64, Value: sqltypes.NewFloat64(42.42).Raw()}},
out: "select CAST(42.42 AS DOUBLE) from ",
},
{
typ: sqltypes.Float32,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Float32, Value: sqltypes.NewFloat32(42).Raw()}},
out: "select CAST(42 AS FLOAT) from ",
},
{
typ: sqltypes.Date,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Date, Value: sqltypes.NewDate("2021-10-30").Raw()}},
out: "select CAST('2021-10-30' AS DATE) from ",
},
{
typ: sqltypes.Time,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Time, Value: sqltypes.NewTime("12:00:00").Raw()}},
out: "select CAST('12:00:00' AS TIME) from ",
},
{
typ: sqltypes.Time,
size: 6,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Time, Value: sqltypes.NewTime("12:00:00").Raw()}},
out: "select CAST('12:00:00' AS TIME(6)) from ",
},
{
typ: sqltypes.Timestamp,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Timestamp, Value: sqltypes.NewTimestamp("2021-10-22 12:00:00").Raw()}},
out: "select CAST('2021-10-22 12:00:00' AS DATETIME) from ",
},
{
typ: sqltypes.Timestamp,
size: 6,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Timestamp, Value: sqltypes.NewTimestamp("2021-10-22 12:00:00").Raw()}},
out: "select CAST('2021-10-22 12:00:00' AS DATETIME(6)) from ",
},
{
typ: sqltypes.Datetime,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Datetime, Value: sqltypes.NewDatetime("2021-10-22 12:00:00").Raw()}},
out: "select CAST('2021-10-22 12:00:00' AS DATETIME) from ",
},
{
typ: sqltypes.Datetime,
size: 6,
binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Datetime, Value: sqltypes.NewDatetime("2021-10-22 12:00:00").Raw()}},
out: "select CAST('2021-10-22 12:00:00' AS DATETIME(6)) from ",
},
}

for _, testcase := range testcases {
t.Run(testcase.out, func(t *testing.T) {
argument := NewTypedArgument("arg", testcase.typ)
if testcase.size > 0 {
argument.Size = int32(testcase.size)
}

s := &Select{
SelectExprs: SelectExprs{
NewAliasedExpr(argument, ""),
},
}

pq := NewParsedQuery(s)
out, err := pq.GenerateQuery(testcase.binds, nil)

require.NoError(t, err)
require.Equal(t, testcase.out, out)
})
}
}
2 changes: 1 addition & 1 deletion go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ insert into t1 (id,intval,floatval) values (1,2,3.14) on duplicate key update in
1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) on duplicate key update intval = 3, floatval = 3.14
2 ks_unsharded/-: commit

----------------------------------------------------------------------
----------------------------------------------------------------------
4 changes: 4 additions & 0 deletions go/vt/vtgate/evalengine/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ func (v *EnumSetValues) Equal(other *EnumSetValues) bool {
return slices.Equal(*v, *other)
}

func NewUnknownType() Type {
return NewType(sqltypes.Unknown, collations.Unknown)
}

func NewType(t sqltypes.Type, collation collations.ID) Type {
// New types default to being nullable
return NewTypeEx(t, collation, true, 0, 0, nil)
Expand Down
3 changes: 0 additions & 3 deletions go/vt/vtgate/evalengine/expr_bvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) {
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "query argument '%s' cannot be a tuple", bv.Key)
}
typ := bvar.Type
if bv.typed() {
typ = bv.Type
}
return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), typedCoercionCollation(typ, collations.CollationForType(typ, bv.Collation)), nil)
}
}
Expand Down
6 changes: 5 additions & 1 deletion go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ func breakExpressionInLHSandRHS(
Name: bvName,
Expr: nodeExpr,
})
arg := sqlparser.NewArgument(bvName)
typeForExpr, _ := ctx.TypeForExpr(nodeExpr)
arg := sqlparser.NewTypedArgument(bvName, typeForExpr.Type())
arg.Scale = typeForExpr.Scale()
arg.Size = typeForExpr.Size()

// we are replacing one of the sides of the comparison with an argument,
// but we don't want to lose the type information we have, so we copy it over
ctx.SemTable.CopyExprInfo(nodeExpr, arg)
Expand Down
Loading

0 comments on commit 4922a3a

Please sign in to comment.