diff --git a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go index 51d9f9f24bf..f7d6f45a784 100644 --- a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go +++ b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go @@ -28,7 +28,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/test/endtoend/utils" "vitess.io/vitess/go/mysql" @@ -41,11 +40,7 @@ func TestNormalizeAllFields(t *testing.T) { insertQuery := `insert into t1 values (1, "chars", "variable chars", x'73757265', 0x676F, 0.33, 9.99, 1, "1976-06-08", "small", "b", "{\"key\":\"value\"}", point(1,5), b'011', 0b0101)` normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL(3,2) */, :vtg7 /* DECIMAL(3,2) */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` - vtgateVersion, err := cluster.GetMajorVersion("vtgate") - require.NoError(t, err) - if vtgateVersion < 20 { - normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` - } + selectQuery := "select * from t1" utils.Exec(t, conn, insertQuery) qr := utils.Exec(t, conn, selectQuery) diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index abbf5ff15e8..b8fcca34f1c 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -23,6 +23,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/test/endtoend/utils" ) @@ -34,7 +36,7 @@ func start(t *testing.T) (utils.MySQLCompare, func()) { deleteAll := func() { _, _ = utils.ExecAllowError(t, mcmp.VtConn, "set workload = oltp") - tables := []string{"t1", "t1_id2_idx", "t2", "t2_id4_idx"} + tables := []string{"t1", "t1_id2_idx", "t2", "t2_id4_idx", "user", "user_extra"} for _, table := range tables { _, _ = mcmp.ExecAndIgnore("delete from " + table) } @@ -232,3 +234,18 @@ func TestSubqueries(t *testing.T) { }) } } + +func TestProperTypesOfPullOutValue(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 21, "vtgate") + + query := "select (select sum(id) from user) from user_extra" + + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("INSERT INTO user (id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'), (4, 'David'), (5, 'Eve'), (6, 'Frank'), (7, 'Grace'), (8, 'Hannah'), (9, 'Ivy'), (10, 'Jack')") + mcmp.Exec("INSERT INTO user_extra (user_id, extra_info) VALUES (1, 'info1'), (2, 'info1'), (3, 'info1'), (3, 'info2'), (4, 'info1'), (5, 'info1'), (6, 'info1'), (7, 'info1'), (8, 'info1')") + + r := mcmp.Exec(query) + require.True(t, r.Fields[0].Type == sqltypes.Decimal) +} diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index 49de08381d2..da88129ee63 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1361,6 +1361,64 @@ func (node *Literal) Format(buf *TrackedBuffer) { // Format formats the node. func (node *Argument) Format(buf *TrackedBuffer) { + // We need to make sure that any value used still returns + // the right type when interpolated. For example, if we have a + // decimal type with 0 scale, we don't want it to be interpreted + // as an integer after interpolation as that would the default + // literal interpretation in MySQL. + switch { + case node.Type == sqltypes.Unknown: + // Ensure we handle unknown first as we don't want to treat + // the type as a bitmask for the further tests. + // do nothing, the default literal will be correct. + case sqltypes.IsDecimal(node.Type) && node.Scale == 0: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.astPrintf(node, " AS DECIMAL(%d, %d))", node.Size, node.Scale) + return + case sqltypes.IsUnsigned(node.Type): + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS UNSIGNED)") + return + case node.Type == sqltypes.Float64: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS DOUBLE)") + return + case node.Type == sqltypes.Float32: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS FLOAT)") + return + case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS DATETIME") + if node.Size == 0 { + buf.WriteString(")") + return + } + buf.astPrintf(node, "(%d))", node.Size) + return + case sqltypes.IsDate(node.Type): + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS DATE") + buf.WriteString(")") + return + case node.Type == sqltypes.Time: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS TIME") + if node.Size == 0 { + buf.WriteString(")") + return + } + buf.astPrintf(node, "(%d))", node.Size) + return + } + // Nothing special to do, the default literal will be correct. buf.WriteArg(":", node.Name) if node.Type >= 0 { // For bind variables that are statically typed, emit their type as an adjacent comment. diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 87626f0b799..b1dd010f5ed 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1780,6 +1780,72 @@ func (node *Literal) FormatFast(buf *TrackedBuffer) { // FormatFast formats the node. func (node *Argument) FormatFast(buf *TrackedBuffer) { + // We need to make sure that any value used still returns + // the right type when interpolated. For example, if we have a + // decimal type with 0 scale, we don't want it to be interpreted + // as an integer after interpolation as that would the default + // literal interpretation in MySQL. + switch { + case node.Type == sqltypes.Unknown: + // Ensure we handle unknown first as we don't want to treat + // the type as a bitmask for the further tests. + // do nothing, the default literal will be correct. + case sqltypes.IsDecimal(node.Type) && node.Scale == 0: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS DECIMAL(") + buf.WriteString(fmt.Sprintf("%d", node.Size)) + buf.WriteString(", ") + buf.WriteString(fmt.Sprintf("%d", node.Scale)) + buf.WriteString("))") + return + case sqltypes.IsUnsigned(node.Type): + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS UNSIGNED)") + return + case node.Type == sqltypes.Float64: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS DOUBLE)") + return + case node.Type == sqltypes.Float32: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS FLOAT)") + return + case node.Type == sqltypes.Timestamp, node.Type == sqltypes.Datetime: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS DATETIME") + if node.Size == 0 { + buf.WriteString(")") + return + } + buf.WriteByte('(') + buf.WriteString(fmt.Sprintf("%d", node.Size)) + buf.WriteString("))") + return + case sqltypes.IsDate(node.Type): + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS DATE") + buf.WriteString(")") + return + case node.Type == sqltypes.Time: + buf.WriteString("CAST(") + buf.WriteArg(":", node.Name) + buf.WriteString(" AS TIME") + if node.Size == 0 { + buf.WriteString(")") + return + } + buf.WriteByte('(') + buf.WriteString(fmt.Sprintf("%d", node.Size)) + buf.WriteString("))") + return + } + // Nothing special to do, the default literal will be correct. buf.WriteArg(":", node.Name) if node.Type >= 0 { // For bind variables that are statically typed, emit their type as an adjacent comment. diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 19b0cfbcac6..c574b00832d 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -82,17 +82,24 @@ func TestNormalize(t *testing.T) { }, { // datetime val in: "select * from t where foobar = timestamp'2012-02-29 12:34:56.123456'", - outstmt: "select * from t where foobar = :foobar /* DATETIME(6) */", + outstmt: "select * from t where foobar = CAST(:foobar AS DATETIME(6))", outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.ValueBindVariable(sqltypes.NewDatetime("2012-02-29 12:34:56.123456")), }, }, { // time val in: "select * from t where foobar = time'12:34:56.123456'", - outstmt: "select * from t where foobar = :foobar /* TIME(6) */", + outstmt: "select * from t where foobar = CAST(:foobar AS TIME(6))", outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56.123456")), }, + }, { + // time val + in: "select * from t where foobar = time'12:34:56'", + outstmt: "select * from t where foobar = CAST(:foobar AS TIME)", + outbv: map[string]*querypb.BindVariable{ + "foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56")), + }, }, { // multiple vals in: "select * from t where foo = 1.2 and bar = 2", @@ -334,21 +341,21 @@ func TestNormalize(t *testing.T) { }, { // DateVal should also be normalized in: `select date'2022-08-06'`, - outstmt: `select :bv1 /* DATE */ from dual`, + outstmt: `select CAST(:bv1 AS DATE) from dual`, outbv: map[string]*querypb.BindVariable{ "bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Date, []byte("2022-08-06"))), }, }, { // TimeVal should also be normalized in: `select time'17:05:12'`, - outstmt: `select :bv1 /* TIME */ from dual`, + outstmt: `select CAST(:bv1 AS TIME) from dual`, outbv: map[string]*querypb.BindVariable{ "bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Time, []byte("17:05:12"))), }, }, { // TimestampVal should also be normalized in: `select timestamp'2022-08-06 17:05:12'`, - outstmt: `select :bv1 /* DATETIME */ from dual`, + outstmt: `select CAST(:bv1 AS DATETIME) from dual`, outbv: map[string]*querypb.BindVariable{ "bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Datetime, []byte("2022-08-06 17:05:12"))), }, diff --git a/go/vt/sqlparser/parsed_query_test.go b/go/vt/sqlparser/parsed_query_test.go index ef59676883f..8ade9d4d31c 100644 --- a/go/vt/sqlparser/parsed_query_test.go +++ b/go/vt/sqlparser/parsed_query_test.go @@ -20,10 +20,11 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" - - "github.com/stretchr/testify/assert" ) func TestNewParsedQuery(t *testing.T) { @@ -205,3 +206,92 @@ func TestParseAndBind(t *testing.T) { }) } } + +func TestCastBindVars(t *testing.T) { + testcases := []struct { + typ sqltypes.Type + size int + binds map[string]*querypb.BindVariable + out string + }{ + { + typ: sqltypes.Decimal, + binds: map[string]*querypb.BindVariable{"arg": sqltypes.DecimalBindVariable("50")}, + out: "select CAST(50 AS DECIMAL(0, 0)) from ", + }, + { + typ: sqltypes.Uint32, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Uint32, Value: sqltypes.NewUint32(42).Raw()}}, + out: "select CAST(42 AS UNSIGNED) from ", + }, + { + typ: sqltypes.Float64, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Float64, Value: sqltypes.NewFloat64(42.42).Raw()}}, + out: "select CAST(42.42 AS DOUBLE) from ", + }, + { + typ: sqltypes.Float32, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Float32, Value: sqltypes.NewFloat32(42).Raw()}}, + out: "select CAST(42 AS FLOAT) from ", + }, + { + typ: sqltypes.Date, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Date, Value: sqltypes.NewDate("2021-10-30").Raw()}}, + out: "select CAST('2021-10-30' AS DATE) from ", + }, + { + typ: sqltypes.Time, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Time, Value: sqltypes.NewTime("12:00:00").Raw()}}, + out: "select CAST('12:00:00' AS TIME) from ", + }, + { + typ: sqltypes.Time, + size: 6, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Time, Value: sqltypes.NewTime("12:00:00").Raw()}}, + out: "select CAST('12:00:00' AS TIME(6)) from ", + }, + { + typ: sqltypes.Timestamp, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Timestamp, Value: sqltypes.NewTimestamp("2021-10-22 12:00:00").Raw()}}, + out: "select CAST('2021-10-22 12:00:00' AS DATETIME) from ", + }, + { + typ: sqltypes.Timestamp, + size: 6, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Timestamp, Value: sqltypes.NewTimestamp("2021-10-22 12:00:00").Raw()}}, + out: "select CAST('2021-10-22 12:00:00' AS DATETIME(6)) from ", + }, + { + typ: sqltypes.Datetime, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Datetime, Value: sqltypes.NewDatetime("2021-10-22 12:00:00").Raw()}}, + out: "select CAST('2021-10-22 12:00:00' AS DATETIME) from ", + }, + { + typ: sqltypes.Datetime, + size: 6, + binds: map[string]*querypb.BindVariable{"arg": {Type: sqltypes.Datetime, Value: sqltypes.NewDatetime("2021-10-22 12:00:00").Raw()}}, + out: "select CAST('2021-10-22 12:00:00' AS DATETIME(6)) from ", + }, + } + + for _, testcase := range testcases { + t.Run(testcase.out, func(t *testing.T) { + argument := NewTypedArgument("arg", testcase.typ) + if testcase.size > 0 { + argument.Size = int32(testcase.size) + } + + s := &Select{ + SelectExprs: SelectExprs{ + NewAliasedExpr(argument, ""), + }, + } + + pq := NewParsedQuery(s) + out, err := pq.GenerateQuery(testcase.binds, nil) + + require.NoError(t, err) + require.Equal(t, testcase.out, out) + }) + } +} diff --git a/go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt b/go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt index 0db31f10110..b7299002d01 100644 --- a/go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt +++ b/go/vt/vtexplain/testdata/twopc-output/unsharded-output.txt @@ -45,4 +45,4 @@ insert into t1 (id,intval,floatval) values (1,2,3.14) on duplicate key update in 1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) on duplicate key update intval = 3, floatval = 3.14 2 ks_unsharded/-: commit ----------------------------------------------------------------------- +---------------------------------------------------------------------- \ No newline at end of file diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index bcb2281f1a6..b0a7edd285d 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -81,6 +81,10 @@ func (v *EnumSetValues) Equal(other *EnumSetValues) bool { return slices.Equal(*v, *other) } +func NewUnknownType() Type { + return NewType(sqltypes.Unknown, collations.Unknown) +} + func NewType(t sqltypes.Type, collation collations.ID) Type { // New types default to being nullable return NewTypeEx(t, collation, true, 0, 0, nil) diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index 0fffe3140a2..daf64296e98 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -83,9 +83,6 @@ func (bv *BindVariable) eval(env *ExpressionEnv) (eval, error) { return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "query argument '%s' cannot be a tuple", bv.Key) } typ := bvar.Type - if bv.typed() { - typ = bv.Type - } return valueToEval(sqltypes.MakeTrusted(typ, bvar.Value), typedCoercionCollation(typ, collations.CollationForType(typ, bv.Collation)), nil) } } diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index a39ae96fa88..17b4bc7c3f1 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -44,7 +44,11 @@ func breakExpressionInLHSandRHS( Name: bvName, Expr: nodeExpr, }) - arg := sqlparser.NewArgument(bvName) + typeForExpr, _ := ctx.TypeForExpr(nodeExpr) + arg := sqlparser.NewTypedArgument(bvName, typeForExpr.Type()) + arg.Scale = typeForExpr.Scale() + arg.Size = typeForExpr.Size() + // we are replacing one of the sides of the comparison with an argument, // but we don't want to lose the type information we have, so we copy it over ctx.SemTable.CopyExprInfo(nodeExpr, arg) diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 533d740f300..2602e39f87e 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -536,26 +536,10 @@ func tryPushOrdering(ctx *plancontext.PlanningContext, in *Ordering) (Operator, return pushOrderingUnderAggr(ctx, in, src) case *SubQueryContainer: return pushOrderingToOuterOfSubqueryContainer(ctx, in, src) - case *SubQuery: - return pushOrderingToOuterOfSubquery(ctx, in, src) } return in, NoRewrite } -func pushOrderingToOuterOfSubquery(ctx *plancontext.PlanningContext, in *Ordering, sq *SubQuery) (Operator, *ApplyResult) { - outerTableID := TableID(sq.Outer) - for idx, order := range in.Order { - deps := ctx.SemTable.RecursiveDeps(order.Inner.Expr) - if !deps.IsSolvedBy(outerTableID) { - return in, NoRewrite - } - in.Order[idx].SimplifiedExpr = sq.rewriteColNameToArgument(order.SimplifiedExpr) - in.Order[idx].Inner.Expr = sq.rewriteColNameToArgument(order.Inner.Expr) - } - sq.Outer, in.Source = in, sq.Outer - return sq, Rewrote("push ordering into outer side of subquery") -} - func pushOrderingToOuterOfSubqueryContainer(ctx *plancontext.PlanningContext, in *Ordering, subq *SubQueryContainer) (Operator, *ApplyResult) { outerTableID := TableID(subq.Outer) for _, order := range in.Order { diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 5729dbd0c2e..352a5ffc7a7 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -68,26 +68,22 @@ type ( // Aggr encodes all information needed for aggregation functions Aggr struct { - Original *sqlparser.AliasedExpr - Func sqlparser.AggrFunc // if we are missing a Func, it means this is a AggregateAnyValue - OpCode opcode.AggregateOpcode + Original *sqlparser.AliasedExpr // The original SQL expression for the aggregation + Func sqlparser.AggrFunc // The aggregation function (e.g., COUNT, SUM). If nil, it means AggregateAnyValue is used + OpCode opcode.AggregateOpcode // The opcode representing the type of aggregation being performed - // OriginalOpCode will contain opcode.AggregateUnassigned unless we are changing opcode while pushing them down + // OriginalOpCode will contain opcode.AggregateUnassigned unless we are changing the opcode while pushing them down OriginalOpCode opcode.AggregateOpcode - Alias string + Alias string // The alias name for the aggregation result - // The index at which the user expects to see this aggregated function. Set to nil, if the user does not ask for it - // Only used in the old Horizon Planner - Index *int + Distinct bool // Whether the aggregation function is DISTINCT - Distinct bool - - // the offsets point to columns on the same aggregator - ColOffset int - WSOffset int + // Offsets pointing to columns within the same aggregator + ColOffset int // Offset for the column being aggregated + WSOffset int // Offset for the weight string of the column - SubQueryExpression []*SubQuery + SubQueryExpression []*SubQuery // Subqueries associated with this aggregation } ) @@ -97,7 +93,7 @@ func (aggr Aggr) NeedsWeightString(ctx *plancontext.PlanningContext) bool { func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) evalengine.Type { if aggr.Func == nil { - return evalengine.Type{} + return evalengine.NewUnknownType() } switch aggr.OpCode { case opcode.AggregateMin, opcode.AggregateMax, opcode.AggregateSumDistinct, opcode.AggregateCountDistinct: @@ -442,14 +438,12 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte // Here we go over the expressions we are returning. Since we know we are aggregating, // all expressions have to be either grouping expressions or aggregate expressions. // If we find an expression that is neither, we treat is as a special aggregation function AggrRandom - for idx, expr := range qp.SelectExprs { + for _, expr := range qp.SelectExprs { aliasedExpr, err := expr.GetAliasedExpr() if err != nil { panic(err) } - idxCopy := idx - if !ContainsAggr(ctx, expr.Col) { getExpr, err := expr.GetExpr() if err != nil { @@ -457,7 +451,6 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte } if !qp.isExprInGroupByExprs(ctx, getExpr) { aggr := NewAggr(opcode.AggregateAnyValue, nil, aliasedExpr, aliasedExpr.ColumnName()) - aggr.Index = &idxCopy out = append(out, aggr) } continue @@ -466,14 +459,13 @@ func (qp *QueryProjection) AggregationExpressions(ctx *plancontext.PlanningConte panic(vterrors.VT12001("in scatter query: complex aggregate expression")) } - sqlparser.CopyOnRewrite(aliasedExpr.Expr, qp.extractAggr(ctx, idx, aliasedExpr, addAggr, makeComplex), nil, nil) + sqlparser.CopyOnRewrite(aliasedExpr.Expr, qp.extractAggr(ctx, aliasedExpr, addAggr, makeComplex), nil, nil) } return } func (qp *QueryProjection) extractAggr( ctx *plancontext.PlanningContext, - idx int, aliasedExpr *sqlparser.AliasedExpr, addAggr func(a Aggr), makeComplex func(), @@ -489,7 +481,6 @@ func (qp *QueryProjection) extractAggr( ae = aliasedExpr } aggrFunc := createAggrFromAggrFunc(aggr, ae) - aggrFunc.Index = &idx addAggr(aggrFunc) return false } @@ -497,7 +488,6 @@ func (qp *QueryProjection) extractAggr( // If we are here, we have a function that is an aggregation but not parsed into an AggrFunc. // This is the case for UDFs - we have to be careful with these because we can't evaluate them in VTGate. aggr := NewAggr(opcode.AggregateUDF, nil, aeWrap(ex), "") - aggr.Index = &idx addAggr(aggr) return false } @@ -507,7 +497,6 @@ func (qp *QueryProjection) extractAggr( } if !qp.isExprInGroupByExprs(ctx, ex) { aggr := NewAggr(opcode.AggregateAnyValue, nil, aeWrap(ex), "") - aggr.Index = &idx addAggr(aggr) } return false diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index a950c3720c2..5ae0fb52e7f 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -309,18 +309,3 @@ func (sq *SubQuery) mapExpr(f func(expr sqlparser.Expr) sqlparser.Expr) { sq.Original = f(sq.Original) sq.originalSubquery = f(sq.originalSubquery).(*sqlparser.Subquery) } - -func (sq *SubQuery) rewriteColNameToArgument(expr sqlparser.Expr) sqlparser.Expr { - pre := func(cursor *sqlparser.Cursor) bool { - colName, ok := cursor.Node().(*sqlparser.ColName) - if !ok || colName.Qualifier.NonEmpty() || !colName.Name.EqualString(sq.ArgName) { - // we only want to rewrite the column name to an argument if it's the right column - return true - } - - cursor.Replace(sqlparser.NewArgument(sq.ArgName)) - return true - } - - return sqlparser.Rewrite(expr, pre, nil).(sqlparser.Expr) -} diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index cdc0b8b191a..5a0aed3f10d 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -364,36 +364,72 @@ func rewriteOriginalPushedToRHS(ctx *plancontext.PlanningContext, expression sql // need to find the argument name for it and use that instead // we can't use the column name directly, because we're in the RHS of the join name := outer.findOrAddColNameBindVarName(ctx, col) - cursor.Replace(sqlparser.NewArgument(name)) + typ, _ := ctx.TypeForExpr(col) + arg := sqlparser.NewTypedArgument(name, typ.Type()) + arg.Scale = typ.Scale() + arg.Size = typ.Size() + cursor.Replace(arg) }, nil) return result.(sqlparser.Expr) } -func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Expr, se SubQueryExpression, subqueries ...*SubQuery) sqlparser.Expr { +// rewriteColNameToArgument rewrites the column names in the expression to use the argument names instead +// this is used when we push an operator from above the subquery into the outer side of the subquery +func rewriteColNameToArgument( + ctx *plancontext.PlanningContext, + in sqlparser.Expr, // the expression to rewrite + se SubQueryExpression, // the subquery expression we are rewriting + subqueries ...*SubQuery, // the inner subquery operators +) sqlparser.Expr { + // the visitor function that will rewrite the expression tree + // it will be invoked on unqualified column names, and replace them with arguments + // when the column is representing a subquery rewriteIt := func(s string) sqlparser.SQLNode { - for _, sq1 := range se { - if sq1.ArgName != s && sq1.HasValuesName != s { - continue + var sq1, sq2 *SubQuery + for _, sq := range se { + if sq.ArgName == s || sq.HasValuesName == s { + sq1 = sq + break + } + } + for _, sq := range subqueries { + if s == sq.ArgName { + sq2 = sq + break } + } - for _, sq2 := range subqueries { - if s == sq2.ArgName { - switch { - case sq1.FilterType.NeedsListArg(): - return sqlparser.NewListArg(s) - case sq1.FilterType == opcode.PulloutExists: - if sq1.HasValuesName == "" { - sq1.HasValuesName = ctx.ReservedVars.ReserveHasValuesSubQuery() - sq2.HasValuesName = sq1.HasValuesName - } - return sqlparser.NewArgument(sq1.HasValuesName) - default: - return sqlparser.NewArgument(s) - } - } + if sq1 == nil || sq2 == nil { + return nil + } + + switch { + case sq1.FilterType.NeedsListArg(): + return sqlparser.NewListArg(s) + case sq1.FilterType == opcode.PulloutExists: + if sq1.HasValuesName == "" { + sq1.HasValuesName = ctx.ReservedVars.ReserveHasValuesSubQuery() + sq2.HasValuesName = sq1.HasValuesName + } + return sqlparser.NewArgument(sq1.HasValuesName) + default: + // for scalar value subqueries, the argument is typed based on the first expression in the subquery + // so here we make an attempt at figuring out the type of the argument + ae, isAe := sq2.originalSubquery.Select.GetColumns()[0].(*sqlparser.AliasedExpr) + if !isAe { + return sqlparser.NewArgument(s) } + + argType, found := ctx.TypeForExpr(ae.Expr) + if !found { + return sqlparser.NewArgument(s) + } + + arg := sqlparser.NewTypedArgument(s, argType.Type()) + arg.Scale = argType.Scale() + arg.Size = argType.Size() + return arg } - return nil } // replace the ColNames with Argument inside the subquery diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 90a6bdac6f8..2f33539f858 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -17,6 +17,7 @@ limitations under the License. package plancontext import ( + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -222,3 +223,12 @@ func (ctx *PlanningContext) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool } return t, true } + +// SQLTypeForExpr returns the sql type of the given expression, with nullable set if the expression is from an outer table. +func (ctx *PlanningContext) SQLTypeForExpr(e sqlparser.Expr) sqltypes.Type { + t, found := ctx.TypeForExpr(e) + if !found { + return sqltypes.Unknown + } + return t.Type() +} diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index a272954725d..6942464665c 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -3372,7 +3372,7 @@ "Sharded": true }, "FieldQuery": "select 1 from user_extra where 1 != 1 group by .0", - "Query": "select 1 from user_extra where user_extra.col = :user_col group by .0", + "Query": "select 1 from user_extra where user_extra.col = :user_col /* INT16 */ group by .0", "Table": "user_extra" } ] @@ -3862,7 +3862,7 @@ "Sharded": true }, "FieldQuery": "select count(*) from user_extra as ue where 1 != 1 group by .0", - "Query": "select count(*) from user_extra as ue where ue.col = :u_col group by .0", + "Query": "select count(*) from user_extra as ue where ue.col = :u_col /* INT16 */ group by .0", "Table": "user_extra" } ] @@ -3922,7 +3922,7 @@ "Sharded": true }, "FieldQuery": "select count(ue.id) from user_extra as ue where 1 != 1 group by .0", - "Query": "select count(ue.id) from user_extra as ue where ue.col = :u_col group by .0", + "Query": "select count(ue.id) from user_extra as ue where ue.col = :u_col /* INT16 */ group by .0", "Table": "user_extra" } ] @@ -5153,7 +5153,7 @@ "Sharded": true }, "FieldQuery": "select min(user_extra.foo), max(user_extra.bar), weight_string(user_extra.foo), weight_string(user_extra.bar) from user_extra where 1 != 1 group by .0, weight_string(user_extra.foo), weight_string(user_extra.bar)", - "Query": "select min(user_extra.foo), max(user_extra.bar), weight_string(user_extra.foo), weight_string(user_extra.bar) from user_extra where user_extra.bar = :user_col group by .0, weight_string(user_extra.foo), weight_string(user_extra.bar)", + "Query": "select min(user_extra.foo), max(user_extra.bar), weight_string(user_extra.foo), weight_string(user_extra.bar) from user_extra where user_extra.bar = :user_col /* INT16 */ group by .0, weight_string(user_extra.foo), weight_string(user_extra.bar)", "Table": "user_extra" } ] @@ -5431,7 +5431,7 @@ "Sharded": true }, "FieldQuery": "select count(*), sum(user_extra.bar) from user_extra where 1 != 1 group by .0", - "Query": "select count(*), sum(user_extra.bar) from user_extra where user_extra.col = :user_col group by .0", + "Query": "select count(*), sum(user_extra.bar) from user_extra where user_extra.col = :user_col /* INT16 */ group by .0", "Table": "user_extra" } ] @@ -5781,7 +5781,7 @@ "Sharded": true }, "FieldQuery": "select 1 from music as m where 1 != 1", - "Query": "select 1 from music as m where m.col = :u_col", + "Query": "select 1 from music as m where m.col = :u_col /* INT16 */", "Table": "music" } ] @@ -6114,8 +6114,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select :__sq1 + :__sq2 as `(select count(*) from ``user``) + (select count(*) from user_extra)` from dual where 1 != 1", - "Query": "select :__sq1 + :__sq2 as `(select count(*) from ``user``) + (select count(*) from user_extra)` from dual", + "FieldQuery": "select :__sq1 /* INT64 */ + :__sq2 /* INT64 */ as `(select count(*) from ``user``) + (select count(*) from user_extra)` from dual where 1 != 1", + "Query": "select :__sq1 /* INT64 */ + :__sq2 /* INT64 */ as `(select count(*) from ``user``) + (select count(*) from user_extra)` from dual", "Table": "dual" } ] @@ -6764,8 +6764,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select max(:__sq1), weight_string(:__sq1) from `user` where 1 != 1 group by weight_string(:__sq1)", - "Query": "select max(:__sq1), weight_string(:__sq1) from `user` where id = 2 group by weight_string(:__sq1)", + "FieldQuery": "select max(:__sq1 /* INT16 */), weight_string(:__sq1 /* INT16 */) from `user` where 1 != 1 group by weight_string(:__sq1 /* INT16 */)", + "Query": "select max(:__sq1 /* INT16 */), weight_string(:__sq1 /* INT16 */) from `user` where id = 2 group by weight_string(:__sq1 /* INT16 */)", "Table": "`user`", "Values": [ "2" diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index 09d155b19f6..4a69fd85fad 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -1132,7 +1132,7 @@ "Sharded": true }, "FieldQuery": "select 1 from user_extra where 1 != 1", - "Query": "select 1 from user_extra where user_extra.col = :user_col", + "Query": "select 1 from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index 9c2ed1920ee..b5d0fa8951f 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -5165,7 +5165,7 @@ "Sharded": true }, "FieldQuery": "select 1 from music as m where 1 != 1", - "Query": "select 1 from music as m where m.col = :u_col", + "Query": "select 1 from music as m where m.col = :u_col /* INT16 */", "Table": "music" } ] @@ -5335,7 +5335,7 @@ "Sharded": true }, "FieldQuery": "select 1 from music as m, user_extra as ue where 1 != 1", - "Query": "select 1 from music as m, user_extra as ue where m.bar = 40 and m.col = :u_col and ue.foo = 20 and m.user_id = ue.user_id", + "Query": "select 1 from music as m, user_extra as ue where m.bar = 40 and m.col = :u_col /* INT16 */ and ue.foo = 20 and m.user_id = ue.user_id", "Table": "music, user_extra" } ] @@ -5408,7 +5408,7 @@ "Sharded": true }, "FieldQuery": "select m.id from music as m, user_extra as ue where 1 != 1", - "Query": "select m.id from music as m, user_extra as ue where m.bar = 40 and m.col = :u_col and ue.foo = 20 and m.user_id = ue.user_id", + "Query": "select m.id from music as m, user_extra as ue where m.bar = 40 and m.col = :u_col /* INT16 */ and ue.foo = 20 and m.user_id = ue.user_id", "Table": "music, user_extra" } ] @@ -5880,7 +5880,7 @@ "Sharded": true }, "TargetTabletType": "PRIMARY", - "Query": "update `user` as u set u.col = :ue_col where u.id in ::dml_vals", + "Query": "update `user` as u set u.col = :ue_col /* INT16 */ where u.id in ::dml_vals", "Table": "user", "Values": [ "::dml_vals" @@ -6483,7 +6483,7 @@ "Sharded": true }, "FieldQuery": "select m.id from music as m where 1 != 1", - "Query": "select m.id from music as m where m.baz = 21 and m.bar = :u_foo and m.col = :u_col for update", + "Query": "select m.id from music as m where m.baz = 21 and m.bar = :u_foo and m.col = :u_col /* INT16 */ for update", "Table": "music" } ] @@ -6738,7 +6738,7 @@ "Sharded": true }, "FieldQuery": "select m.id from music as m where 1 != 1", - "Query": "select m.id from music as m where m.col = :u_col for update", + "Query": "select m.id from music as m where m.col = :u_col /* INT16 */ for update", "Table": "music" } ] @@ -6961,7 +6961,7 @@ "Sharded": true }, "FieldQuery": "select m.id from music as m where 1 != 1", - "Query": "select m.id from music as m where m.col = :u_col for update", + "Query": "select m.id from music as m where m.col = :u_col /* INT16 */ for update", "Table": "music" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.json b/go/vt/vtgate/planbuilder/testdata/filter_cases.json index d36c060ed6d..b60e8812dda 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.json @@ -1141,7 +1141,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select user_extra.id from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.id from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -1213,7 +1213,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select user_extra.id from user_extra where user_extra.user_id = :user_col and user_extra.col = :user_col", + "Query": "select user_extra.id from user_extra where user_extra.user_id = :user_col /* INT16 */ and user_extra.col = :user_col /* INT16 */", "Table": "user_extra", "Values": [ ":user_col" @@ -1262,7 +1262,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select user_extra.id from user_extra where user_extra.col = :user_col and 1 = 1", + "Query": "select user_extra.id from user_extra where user_extra.col = :user_col /* INT16 */ and 1 = 1", "Table": "user_extra" } ] @@ -1614,7 +1614,7 @@ "Sharded": true }, "FieldQuery": "select u.m from `user` as u where 1 != 1", - "Query": "select u.m from `user` as u where u.id in ::__vals and u.id in (select m2 from `user` where `user`.id = u.id and `user`.col = :user_extra_col)", + "Query": "select u.m from `user` as u where u.id in ::__vals and u.id in (select m2 from `user` where `user`.id = u.id and `user`.col = :user_extra_col /* INT16 */)", "Table": "`user`", "Values": [ "(:user_extra_col, 1)" @@ -1758,7 +1758,7 @@ "Sharded": true }, "FieldQuery": "select u.m from `user` as u where 1 != 1", - "Query": "select u.m from `user` as u where u.id in ::__vals and u.id in (select m2 from `user` where `user`.id = u.id and `user`.col = :user_extra_col and `user`.id in (select m3 from user_extra where user_extra.user_id = `user`.id))", + "Query": "select u.m from `user` as u where u.id in ::__vals and u.id in (select m2 from `user` where `user`.id = u.id and `user`.col = :user_extra_col /* INT16 */ and `user`.id in (select m3 from user_extra where user_extra.user_id = `user`.id))", "Table": "`user`", "Values": [ "(:user_extra_col, 1)" @@ -3100,7 +3100,7 @@ "Sharded": true }, "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where `user`.id = :user_extra_col", + "Query": "select id from `user` where `user`.id = :user_extra_col /* INT16 */", "Table": "`user`", "Values": [ ":user_extra_col" @@ -3171,7 +3171,7 @@ "Sharded": true }, "FieldQuery": "select 1 from user_extra where 1 != 1", - "Query": "select 1 from user_extra where user_extra.foobar = 5 and user_extra.col = :user_col", + "Query": "select 1 from user_extra where user_extra.foobar = 5 and user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -3226,7 +3226,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select user_extra.id from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.id from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -4228,7 +4228,7 @@ "Sharded": true }, "FieldQuery": "select count(*) from `user` as b where 1 != 1 group by .0", - "Query": "select count(*) from `user` as b where b.textcol2 = :a_textcol1 group by .0", + "Query": "select count(*) from `user` as b where b.textcol2 = :a_textcol1 /* VARCHAR */ group by .0", "Table": "`user`" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 81381f3d7d7..6db17511a2a 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -585,7 +585,7 @@ "Sharded": false }, "FieldQuery": "select m1.col from unsharded as m1 where 1 != 1", - "Query": "select m1.col from unsharded as m1 where m1.col = :user_col", + "Query": "select m1.col from unsharded as m1 where m1.col = :user_col /* INT16 */", "Table": "unsharded" } ] @@ -651,7 +651,7 @@ "Sharded": true }, "FieldQuery": "select e.col from user_extra as e where 1 != 1", - "Query": "select e.col from user_extra as e where e.col = :user_col", + "Query": "select e.col from user_extra as e where e.col = :user_col /* INT16 */", "Table": "user_extra" }, { @@ -662,7 +662,7 @@ "Sharded": false }, "FieldQuery": "select 1 from unsharded as m1 where 1 != 1", - "Query": "select 1 from unsharded as m1 where m1.col = :e_col", + "Query": "select 1 from unsharded as m1 where m1.col = :e_col /* INT16 */", "Table": "unsharded" } ] @@ -1221,7 +1221,7 @@ "Sharded": true }, "FieldQuery": "select `user`.col from `user` where 1 != 1", - "Query": "select `user`.col from `user` where `user`.id = :user_extra_col", + "Query": "select `user`.col from `user` where `user`.id = :user_extra_col /* INT16 */", "Table": "`user`", "Values": [ ":user_extra_col" @@ -1924,7 +1924,7 @@ "Sharded": true }, "FieldQuery": "select 1 from user_extra where 1 != 1", - "Query": "select 1 from user_extra where user_extra.col = :user_col", + "Query": "select 1 from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -3244,7 +3244,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as uu where 1 != 1", - "Query": "select 1 from `user` as uu where uu.intcol = :u_intcol", + "Query": "select 1 from `user` as uu where uu.intcol = :u_intcol /* INT16 */", "Table": "`user`" } ] @@ -3357,7 +3357,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.col from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -3594,7 +3594,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.col from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -3654,7 +3654,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.col from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -3720,7 +3720,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.col from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -3799,7 +3799,7 @@ "Sharded": false }, "FieldQuery": "select unsharded_authoritative.col2 from unsharded_authoritative where 1 != 1", - "Query": "select unsharded_authoritative.col2 from unsharded_authoritative where unsharded_authoritative.col1 = :authoritative_col1", + "Query": "select unsharded_authoritative.col2 from unsharded_authoritative where unsharded_authoritative.col1 = :authoritative_col1 /* VARCHAR */", "Table": "unsharded_authoritative" } ] @@ -3921,7 +3921,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.col from user_extra where 1 != 1", - "Query": "select user_extra.col from user_extra where user_extra.col = :user_col", + "Query": "select user_extra.col from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] @@ -4430,7 +4430,7 @@ "Sharded": true }, "FieldQuery": "select 1 from music as m where 1 != 1", - "Query": "select 1 from music as m where m.user_id = 5 and m.id = 20 and m.col = :u_col", + "Query": "select 1 from music as m where m.user_id = 5 and m.id = 20 and m.col = :u_col /* INT16 */", "Table": "music", "Values": [ "20" diff --git a/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json b/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json index 12ddfa6e049..31246a2f40f 100644 --- a/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/info_schema57_cases.json @@ -319,7 +319,7 @@ "Sharded": false }, "FieldQuery": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where 1 != 1", - "Query": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where rc.constraint_schema = :__vtschemaname /* VARCHAR */ and rc.constraint_name = :kcu_constraint_name", + "Query": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where rc.constraint_schema = :__vtschemaname /* VARCHAR */ and rc.constraint_name = :kcu_constraint_name /* VARCHAR(64) */", "SysTableTableSchema": "[:v2]", "Table": "information_schema.referential_constraints" } @@ -723,7 +723,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` where 1 != 1", - "Query": "select 1 from `user` where `user`.id = :x_COLUMN_NAME", + "Query": "select 1 from `user` where `user`.id = :x_COLUMN_NAME /* VARCHAR(64) */", "Table": "`user`", "Values": [ ":x_COLUMN_NAME" diff --git a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json index 3eec3685fd2..9553210174c 100644 --- a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json @@ -319,7 +319,7 @@ "Sharded": false }, "FieldQuery": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where 1 != 1", - "Query": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where rc.constraint_schema = :__vtschemaname /* VARCHAR */ and rc.constraint_name = :kcu_constraint_name", + "Query": "select rc.delete_rule as delete_rule, rc.update_rule as update_rule from information_schema.referential_constraints as rc where rc.constraint_schema = :__vtschemaname /* VARCHAR */ and rc.constraint_name = :kcu_constraint_name /* VARCHAR(64) */", "SysTableTableSchema": "[:v2]", "Table": "information_schema.referential_constraints" } @@ -445,7 +445,7 @@ "Sharded": false }, "FieldQuery": "select 1 from information_schema.table_constraints as tc where 1 != 1", - "Query": "select 1 from information_schema.table_constraints as tc where tc.table_schema = :__vtschemaname /* VARCHAR */ and tc.table_name = :tc_table_name /* VARCHAR */ and tc.constraint_name = :cc_constraint_name and tc.constraint_schema = :__vtschemaname /* VARCHAR */", + "Query": "select 1 from information_schema.table_constraints as tc where tc.table_schema = :__vtschemaname /* VARCHAR */ and tc.table_name = :tc_table_name /* VARCHAR */ and tc.constraint_name = :cc_constraint_name /* VARCHAR(64) */ and tc.constraint_schema = :__vtschemaname /* VARCHAR */", "SysTableTableName": "[tc_table_name:'table_name']", "SysTableTableSchema": "['table_schema', :cc_constraint_schema]", "Table": "information_schema.table_constraints" @@ -788,7 +788,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` where 1 != 1", - "Query": "select 1 from `user` where `user`.id = :x_COLUMN_NAME", + "Query": "select 1 from `user` where `user`.id = :x_COLUMN_NAME /* VARCHAR(64) */", "Table": "`user`", "Values": [ ":x_COLUMN_NAME" diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json index 454740f0498..010e22c2108 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json @@ -949,7 +949,7 @@ "Sharded": true }, "FieldQuery": "select e.id from user_extra as e where 1 != 1", - "Query": "select e.id from user_extra as e where e.col = :u_col", + "Query": "select e.id from user_extra as e where e.col = :u_col /* INT16 */", "Table": "user_extra" } ] @@ -2060,8 +2060,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select coalesce(:user_col, user_extra.col), weight_string(coalesce(:user_col, user_extra.col)) from user_extra where 1 != 1", - "Query": "select coalesce(:user_col, user_extra.col), weight_string(coalesce(:user_col, user_extra.col)) from user_extra", + "FieldQuery": "select coalesce(:user_col /* INT16 */, user_extra.col), weight_string(coalesce(:user_col /* INT16 */, user_extra.col)) from user_extra where 1 != 1", + "Query": "select coalesce(:user_col /* INT16 */, user_extra.col), weight_string(coalesce(:user_col /* INT16 */, user_extra.col)) from user_extra", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/rails_cases.json b/go/vt/vtgate/planbuilder/testdata/rails_cases.json index c8ab8b7b9d8..3887547e628 100644 --- a/go/vt/vtgate/planbuilder/testdata/rails_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/rails_cases.json @@ -62,7 +62,7 @@ "Sharded": true }, "FieldQuery": "select 1 from book6s_order2s where 1 != 1", - "Query": "select 1 from book6s_order2s where book6s_order2s.order2_id = :order2s_id and book6s_order2s.book6_id = :book6s_id", + "Query": "select 1 from book6s_order2s where book6s_order2s.order2_id = :order2s_id /* INT64 */ and book6s_order2s.book6_id = :book6s_id /* INT64 */", "Table": "book6s_order2s", "Values": [ ":book6s_id" @@ -79,7 +79,7 @@ "Sharded": true }, "FieldQuery": "select 1 from supplier5s where 1 != 1", - "Query": "select 1 from supplier5s where supplier5s.id = :book6s_supplier5_id", + "Query": "select 1 from supplier5s where supplier5s.id = :book6s_supplier5_id /* INT64 */", "Table": "supplier5s", "Values": [ ":book6s_supplier5_id" diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index fdb189d067b..51aae618daf 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -1418,8 +1418,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select a, :__sq1 as `(select col from ``user``)` from unsharded where 1 != 1", - "Query": "select a, :__sq1 as `(select col from ``user``)` from unsharded", + "FieldQuery": "select a, :__sq1 /* INT16 */ as `(select col from ``user``)` from unsharded where 1 != 1", + "Query": "select a, :__sq1 /* INT16 */ as `(select col from ``user``)` from unsharded", "Table": "unsharded" } ] @@ -1463,8 +1463,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select a, 1 + :__sq1 as `1 + (select col from ``user``)` from unsharded where 1 != 1", - "Query": "select a, 1 + :__sq1 as `1 + (select col from ``user``)` from unsharded", + "FieldQuery": "select a, 1 + :__sq1 /* INT16 */ as `1 + (select col from ``user``)` from unsharded where 1 != 1", + "Query": "select a, 1 + :__sq1 /* INT16 */ as `1 + (select col from ``user``)` from unsharded", "Table": "unsharded" } ] @@ -2233,7 +2233,7 @@ "Sharded": true }, "FieldQuery": "select `user`.col from `user` where 1 != 1", - "Query": "select `user`.col from `user` where `user`.col = :t_title and `user`.id <= 4", + "Query": "select `user`.col from `user` where `user`.col = :t_title /* VARCHAR */ and `user`.id <= 4", "Table": "`user`" } ] @@ -2510,8 +2510,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select t.a from (select :__sq1 as a from `user` where 1 != 1) as t where 1 != 1", - "Query": "select t.a from (select :__sq1 as a from `user`) as t", + "FieldQuery": "select t.a from (select :__sq1 /* INT16 */ as a from `user` where 1 != 1) as t where 1 != 1", + "Query": "select t.a from (select :__sq1 /* INT16 */ as a from `user`) as t", "Table": "`user`" } ] @@ -2580,8 +2580,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select :__sq1 as a from `user` where 1 != 1", - "Query": "select :__sq1 as a from `user`", + "FieldQuery": "select :__sq1 /* INT16 */ as a from `user` where 1 != 1", + "Query": "select :__sq1 /* INT16 */ as a from `user`", "Table": "`user`" } ] @@ -2717,6 +2717,58 @@ ] } }, + { + "comment": "PullOut subquery with an aggregation that should be typed in the final output", + "query": "select (select sum(col) from user) from user_extra", + "plan": { + "QueryType": "SELECT", + "Original": "select (select sum(col) from user) from user_extra", + "Instructions": { + "OperatorType": "UncorrelatedSubquery", + "Variant": "PulloutValue", + "PulloutVars": [ + "__sq1" + ], + "Inputs": [ + { + "InputName": "SubQuery", + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum(0) AS sum(col)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select sum(col) from `user` where 1 != 1", + "Query": "select sum(col) from `user`", + "Table": "`user`" + } + ] + }, + { + "InputName": "Outer", + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select CAST(:__sq1 AS DECIMAL(0, 0)) as `(select sum(col) from ``user``)` from user_extra where 1 != 1", + "Query": "select CAST(:__sq1 AS DECIMAL(0, 0)) as `(select sum(col) from ``user``)` from user_extra", + "Table": "user_extra" + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, { "comment": "Straight Join preserved in MySQL query", "query": "select user.id, user_extra.user_id from user straight_join user_extra where user.id = user_extra.user_id", @@ -2915,7 +2967,7 @@ "Sharded": true }, "FieldQuery": "select 1 from user_extra as ue where 1 != 1", - "Query": "select 1 from user_extra as ue where ue.col = :u1_col and ue.col = :u2_col limit 1", + "Query": "select 1 from user_extra as ue where ue.col = :u1_col /* INT16 */ and ue.col = :u2_col /* INT16 */ limit 1", "Table": "user_extra" } ] @@ -2972,7 +3024,7 @@ "Sharded": true }, "FieldQuery": "select 1 from user_extra as ue where 1 != 1", - "Query": "select 1 from user_extra as ue where ue.col = :u_col and ue.col2 = :u_col limit 1", + "Query": "select 1 from user_extra as ue where ue.col = :u_col /* INT16 */ and ue.col2 = :u_col /* INT16 */ limit 1", "Table": "user_extra" } ] @@ -3111,8 +3163,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select :user_extra_col + `user`.col as `user_extra.col + ``user``.col` from `user` where 1 != 1", - "Query": "select :user_extra_col + `user`.col as `user_extra.col + ``user``.col` from `user` where `user`.id = :user_extra_id", + "FieldQuery": "select :user_extra_col /* INT16 */ + `user`.col as `user_extra.col + ``user``.col` from `user` where 1 != 1", + "Query": "select :user_extra_col /* INT16 */ + `user`.col as `user_extra.col + ``user``.col` from `user` where `user`.id = :user_extra_id", "Table": "`user`", "Values": [ ":user_extra_id" @@ -3689,7 +3741,7 @@ "Sharded": true }, "FieldQuery": "select user_metadata.user_id from user_extra, user_metadata where 1 != 1", - "Query": "select user_metadata.user_id from user_extra, user_metadata where user_extra.col = :user_col and user_extra.user_id = user_metadata.user_id", + "Query": "select user_metadata.user_id from user_extra, user_metadata where user_extra.col = :user_col /* INT16 */ and user_extra.user_id = user_metadata.user_id", "Table": "user_extra, user_metadata" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.json b/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.json index de5356346b2..3ac35761051 100644 --- a/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.json @@ -265,7 +265,7 @@ "Sharded": false }, "FieldQuery": "select unsharded.id from unsharded where 1 != 1", - "Query": "select unsharded.id from unsharded where unsharded.id = :user_index_id", + "Query": "select unsharded.id from unsharded where unsharded.id = :user_index_id /* VARBINARY */", "Table": "unsharded" } ] @@ -313,7 +313,7 @@ "Sharded": false }, "FieldQuery": "select unsharded.id from unsharded where 1 != 1", - "Query": "select unsharded.id from unsharded where unsharded.id = :user_index_id", + "Query": "select unsharded.id from unsharded where unsharded.id = :user_index_id /* VARBINARY */", "Table": "unsharded" } ] @@ -361,7 +361,7 @@ "Sharded": false }, "FieldQuery": "select unsharded.id from unsharded where 1 != 1", - "Query": "select unsharded.id from unsharded where unsharded.id = :user_index_id", + "Query": "select unsharded.id from unsharded where unsharded.id = :user_index_id /* VARBINARY */", "Table": "unsharded" } ] @@ -409,7 +409,7 @@ "Sharded": false }, "FieldQuery": "select unsharded.id from unsharded where 1 != 1", - "Query": "select unsharded.id from unsharded where unsharded.id = :ui_id", + "Query": "select unsharded.id from unsharded where unsharded.id = :ui_id /* VARBINARY */", "Table": "unsharded" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/wireup_cases.json b/go/vt/vtgate/planbuilder/testdata/wireup_cases.json index 3aca1f1dc66..62a3e65a35f 100644 --- a/go/vt/vtgate/planbuilder/testdata/wireup_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/wireup_cases.json @@ -148,7 +148,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u3 where 1 != 1", - "Query": "select 1 from `user` as u3 where u3.col = :u1_col", + "Query": "select 1 from `user` as u3 where u3.col = :u1_col /* INT16 */", "Table": "`user`" } ] @@ -210,7 +210,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u3 where 1 != 1", - "Query": "select 1 from `user` as u3 where u3.col = :u2_col", + "Query": "select 1 from `user` as u3 where u3.col = :u2_col /* INT16 */", "Table": "`user`" } ] @@ -265,7 +265,7 @@ "Sharded": true }, "FieldQuery": "select u1.id, u1.col from `user` as u1 where 1 != 1", - "Query": "select u1.id, u1.col from `user` as u1 where u1.col = :u3_col", + "Query": "select u1.id, u1.col from `user` as u1 where u1.col = :u3_col /* INT16 */", "Table": "`user`" }, { @@ -276,7 +276,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u2 where 1 != 1", - "Query": "select 1 from `user` as u2 where u2.col = :u1_col", + "Query": "select 1 from `user` as u2 where u2.col = :u1_col /* INT16 */", "Table": "`user`" } ] @@ -348,7 +348,7 @@ "Sharded": true }, "FieldQuery": "select u1.id, u1.col from `user` as u1 where 1 != 1", - "Query": "select u1.id, u1.col from `user` as u1 where u1.col = :u4_col", + "Query": "select u1.id, u1.col from `user` as u1 where u1.col = :u4_col /* INT16 */", "Table": "`user`" }, { @@ -359,7 +359,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u3 where 1 != 1", - "Query": "select 1 from `user` as u3 where u3.id = :u1_col", + "Query": "select 1 from `user` as u3 where u3.id = :u1_col /* INT16 */", "Table": "`user`", "Values": [ ":u1_col" @@ -420,7 +420,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u2 where 1 != 1", - "Query": "select 1 from `user` as u2 where u2.id = :u1_col", + "Query": "select 1 from `user` as u2 where u2.id = :u1_col /* INT16 */", "Table": "`user`", "Values": [ ":u1_col" @@ -437,7 +437,7 @@ "Sharded": true }, "FieldQuery": "select 1 from `user` as u3 where 1 != 1", - "Query": "select 1 from `user` as u3 where u3.id = :u1_col", + "Query": "select 1 from `user` as u3 where u3.id = :u1_col /* INT16 */", "Table": "`user`", "Values": [ ":u1_col" @@ -591,7 +591,7 @@ "Sharded": true }, "FieldQuery": "select e.id from user_extra as e where 1 != 1", - "Query": "select e.id from user_extra as e where e.id = :u_col limit 10", + "Query": "select e.id from user_extra as e where e.id = :u_col /* INT16 */ limit 10", "Table": "user_extra" } ] @@ -658,7 +658,7 @@ "Sharded": true }, "FieldQuery": "select :u_id + e.id as `u.id + e.id` from user_extra as e where 1 != 1", - "Query": "select :u_id + e.id as `u.id + e.id` from user_extra as e where e.id = :u_col limit 10", + "Query": "select :u_id + e.id as `u.id + e.id` from user_extra as e where e.id = :u_col /* INT16 */ limit 10", "Table": "user_extra" } ] @@ -737,8 +737,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select u.id, :__sq1 as `(select col from ``user``)`, u.col from `user` as u where 1 != 1", - "Query": "select u.id, :__sq1 as `(select col from ``user``)`, u.col from `user` as u", + "FieldQuery": "select u.id, :__sq1 /* INT16 */ as `(select col from ``user``)`, u.col from `user` as u where 1 != 1", + "Query": "select u.id, :__sq1 /* INT16 */ as `(select col from ``user``)`, u.col from `user` as u", "Table": "`user`" } ] @@ -751,7 +751,7 @@ "Sharded": true }, "FieldQuery": "select e.id from user_extra as e where 1 != 1", - "Query": "select e.id from user_extra as e where e.id = :u_col", + "Query": "select e.id from user_extra as e where e.id = :u_col /* INT16 */", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index 90a36b1f0d7..78148f4bb1f 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -169,7 +169,7 @@ func (b *binder) findDependentTableSet(current *scope, target sqlparser.TableNam continue } ts := b.org.tableSetFor(table.GetAliasedTableExpr()) - c := createCertain(ts, ts, evalengine.Type{}) + c := createCertain(ts, ts, evalengine.NewUnknownType()) deps = deps.merge(c, false) } finalDep, err := deps.get(nil) diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 1dcaaf87061..0544764b04f 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -671,7 +671,7 @@ func (st *SemTable) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) { return evalengine.NewTypeEx(sqltypes.VarBinary, collations.CollationBinaryID, wt.Nullable(), 0, 0, nil), true } - return evalengine.Type{}, false + return evalengine.NewUnknownType(), false } // NeedsWeightString returns true if the given expression needs weight_string to do safe comparisons diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index ae107cc070c..948edb37d47 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -19,8 +19,6 @@ package semantics import ( "fmt" - "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" querypb "vitess.io/vitess/go/vt/proto/query" @@ -234,7 +232,7 @@ for2: continue for2 } } - types = append(types, evalengine.NewType(sqltypes.Unknown, collations.Unknown)) + types = append(types, evalengine.NewUnknownType()) } return colNames, types }