diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index af3682baa35..59dc42de060 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -116,7 +116,6 @@ func TestSubqueryInINClause(t *testing.T) { } func TestSubqueryInUpdate(t *testing.T) { - utils.SkipIfBinaryIsBelowVersion(t, 14, "vtgate") mcmp, closer := start(t) defer closer() @@ -131,7 +130,6 @@ func TestSubqueryInUpdate(t *testing.T) { } func TestSubqueryInReference(t *testing.T) { - utils.SkipIfBinaryIsBelowVersion(t, 14, "vtgate") mcmp, closer := start(t) defer closer() @@ -177,3 +175,16 @@ func TestSubqueryInAggregation(t *testing.T) { // This fails as the planner adds `weight_string` method which make the query fail on MySQL. // mcmp.Exec(`SELECT max((select min(id2) from t1 where t1.id1 = t.id1)) FROM t1 t`) } + +// TestSubqueryInDerivedTable tests that subqueries and derived tables +// are handled correctly when there are joins inside the derived table +func TestSubqueryInDerivedTable(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("INSERT INTO t1 (id1, id2) VALUES (1, 100), (2, 200), (3, 300), (4, 400), (5, 500);") + mcmp.Exec("INSERT INTO t2 (id3, id4) VALUES (10, 1), (20, 2), (30, 3), (40, 4), (50, 99)") + mcmp.Exec(`select t.a from (select t1.id2, t2.id3, (select id2 from t1 order by id2 limit 1) as a from t1 join t2 on t1.id1 = t2.id4) t`) + mcmp.Exec(`SELECT COUNT(*) FROM (SELECT DISTINCT t1.id1 FROM t1 JOIN t2 ON t1.id1 = t2.id4) dt`) +} diff --git a/go/test/endtoend/vtgate/schematracker/sharded_prs/st_sharded_test.go b/go/test/endtoend/vtgate/schematracker/sharded_prs/st_sharded_test.go index 3ff0b61b482..6ff8e69bb52 100644 --- a/go/test/endtoend/vtgate/schematracker/sharded_prs/st_sharded_test.go +++ b/go/test/endtoend/vtgate/schematracker/sharded_prs/st_sharded_test.go @@ -209,7 +209,6 @@ func TestMain(m *testing.M) { func TestAddColumn(t *testing.T) { defer cluster.PanicHandler(t) - utils.SkipIfBinaryIsBelowVersion(t, 14, "vtgate") ctx := context.Background() conn, err := mysql.Connect(ctx, &vtParams) require.NoError(t, err) diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 07003876364..a23c6311d24 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -429,6 +429,7 @@ func (node TableName) IsEmpty() bool { // If Name is empty, Qualifier is also empty. return node.Name.IsEmpty() } +func (node TableName) NonEmpty() bool { return !node.Name.IsEmpty() } // NewWhere creates a WHERE or HAVING clause out // of a Expr. If the expression is nil, it returns nil. diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 9603b5da5bc..12f444bc4ac 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -26,28 +26,26 @@ import ( "testing" "time" - _flag "vitess.io/vitess/go/internal/flag" - "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/streamlog" - "vitess.io/vitess/go/vt/topo/topoproto" - "vitess.io/vitess/go/vt/vtenv" - "vitess.io/vitess/go/vt/vtgate/logstats" - "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + _flag "vitess.io/vitess/go/internal/flag" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/streamlog" "vitess.io/vitess/go/test/utils" "vitess.io/vitess/go/vt/discovery" - "vitess.io/vitess/go/vt/vterrors" - _ "vitess.io/vitess/go/vt/vtgate/vindexes" - "vitess.io/vitess/go/vt/vttablet/sandboxconn" - querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/topo/topoproto" + "vitess.io/vitess/go/vt/vtenv" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/logstats" + _ "vitess.io/vitess/go/vt/vtgate/vindexes" + "vitess.io/vitess/go/vt/vttablet/sandboxconn" ) func TestSelectNext(t *testing.T) { @@ -3912,14 +3910,14 @@ func TestSelectAggregationNoData(t *testing.T) { { sql: `select count(*) from (select col1, col2 from user limit 2) x`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64")), - expSandboxQ: "select col1, col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"count(*)" type:INT64]`, expRow: `[[INT64(0)]]`, }, { sql: `select col2, count(*) from (select col1, col2 from user limit 2) x group by col2`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary")), - expSandboxQ: "select col1, col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`, expRow: `[]`, }, @@ -4004,70 +4002,70 @@ func TestSelectAggregationData(t *testing.T) { { sql: `select count(*) from (select col1, col2 from user limit 2) x`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1", "int64|int64|int64"), "100|200|1", "200|300|1"), - expSandboxQ: "select col1, col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, 1 from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"count(*)" type:INT64]`, expRow: `[[INT64(2)]]`, }, { sql: `select col2, count(*) from (select col1, col2 from user limit 9) x group by col2`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|1|weight_string(col2)", "int64|int64|int64|varbinary"), "100|3|1|NULL", "200|2|1|NULL"), - expSandboxQ: "select col1, col2, 1, weight_string(col2) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, 1, weight_string(x.col2) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col2" type:INT64 name:"count(*)" type:INT64]`, expRow: `[[INT64(2) INT64(4)] [INT64(3) INT64(5)]]`, }, { sql: `select count(col1) from (select id, col1 from user limit 2) x`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col1", "int64|varchar"), "1|a", "2|b"), - expSandboxQ: "select id, col1 from (select id, col1 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.id, x.col1 from (select id, col1 from `user`) as x limit :__upper_limit", expField: `[name:"count(col1)" type:INT64]`, expRow: `[[INT64(2)]]`, }, { sql: `select count(col1), col2 from (select col2, col1 from user limit 9) x group by col2`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col2|col1|weight_string(col2)", "int64|varchar|varbinary"), "3|a|NULL", "2|b|NULL"), - expSandboxQ: "select col2, col1, weight_string(col2) from (select col2, col1 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col2, x.col1, weight_string(x.col2) from (select col2, col1 from `user`) as x limit :__upper_limit", expField: `[name:"count(col1)" type:INT64 name:"col2" type:INT64]`, expRow: `[[INT64(4) INT64(2)] [INT64(5) INT64(3)]]`, }, { sql: `select col1, count(col2) from (select col1, col2 from user limit 9) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "a|1|a", "b|null|b"), - expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"count(col2)" type:INT64]`, expRow: `[[VARCHAR("a") INT64(5)] [VARCHAR("b") INT64(0)]]`, }, { sql: `select col1, count(col2) from (select col1, col2 from user limit 32) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "null|1|null", "null|null|null", "a|1|a", "b|null|b"), - expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"count(col2)" type:INT64]`, expRow: `[[NULL INT64(8)] [VARCHAR("a") INT64(8)] [VARCHAR("b") INT64(0)]]`, }, { sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|int64|varbinary"), "a|3|a"), - expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:DECIMAL]`, expRow: `[[VARCHAR("a") DECIMAL(12)]]`, }, { sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|2|a"), - expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`, expRow: `[[VARCHAR("a") FLOAT64(8)]]`, }, { sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|x|a"), - expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`, expRow: `[[VARCHAR("a") FLOAT64(0)]]`, }, { sql: `select col1, sum(col2) from (select col1, col2 from user limit 4) x group by col1`, sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col1|col2|weight_string(col1)", "varchar|varchar|varbinary"), "a|null|a"), - expSandboxQ: "select col1, col2, weight_string(col1) from (select col1, col2 from `user`) as x limit :__upper_limit", + expSandboxQ: "select x.col1, x.col2, weight_string(x.col1) from (select col1, col2 from `user`) as x limit :__upper_limit", expField: `[name:"col1" type:VARCHAR name:"sum(col2)" type:FLOAT64]`, expRow: `[[VARCHAR("a") NULL]]`, }, diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 58b0cfe6545..998f849ba3c 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -21,6 +21,7 @@ import ( "slices" "sort" + "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" @@ -84,7 +85,7 @@ func (qb *queryBuilder) addTableExpr( } func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) { - if _, toBeSkipped := qb.ctx.SkipPredicates[expr]; toBeSkipped { + if qb.ctx.ShouldSkip(expr) { // This is a predicate that was added to the RHS of an ApplyJoin. // The original predicate will be added, so we don't have to add this here return @@ -523,20 +524,24 @@ func buildProjection(op *Projection, qb *queryBuilder) { } func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) { + predicates := slice.Map(op.JoinPredicates.columns, func(jc applyJoinColumn) sqlparser.Expr { + // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done + err := qb.ctx.SkipJoinPredicates(jc.Original) + if err != nil { + panic(err) + } + return jc.Original + }) + pred := sqlparser.AndExpressions(predicates...) + buildQuery(op.LHS, qb) - // If we are going to add the predicate used in join here - // We should not add the predicate's copy of when it was split into - // two parts. To avoid this, we use the SkipPredicates map. - for _, expr := range qb.ctx.JoinPredicates[op.Predicate] { - qb.ctx.SkipPredicates[expr] = nil - } + qbR := &queryBuilder{ctx: qb.ctx} buildQuery(op.RHS, qbR) - if op.LeftJoin { - qb.joinOuterWith(qbR, op.Predicate) + qb.joinOuterWith(qbR, pred) } else { - qb.joinInnerWith(qbR, op.Predicate) + qb.joinInnerWith(qbR, pred) } } diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index f8148fb3f0e..6c2171cf689 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -120,6 +120,14 @@ func (a *Aggregator) isDerived() bool { return a.DT != nil } +func (a *Aggregator) derivedName() string { + if a.DT == nil { + return "" + } + + return a.DT.Alias +} + func (a *Aggregator) FindCol(ctx *plancontext.PlanningContext, in sqlparser.Expr, underRoute bool) int { if underRoute && a.isDerived() { // We don't want to use columns on this operator if it's a derived table under a route. diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index 79b92687a49..9294311c00f 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -37,9 +37,6 @@ type ( // LeftJoin will be true in the case of an outer join LeftJoin bool - // Before offset planning - Predicate sqlparser.Expr - // JoinColumns keeps track of what AST expression is represented in the Columns array JoinColumns *applyJoinColumns @@ -85,16 +82,17 @@ type ( } ) -func NewApplyJoin(lhs, rhs Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin { - return &ApplyJoin{ +func NewApplyJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin { + aj := &ApplyJoin{ LHS: lhs, RHS: rhs, Vars: map[string]int{}, - Predicate: predicate, LeftJoin: leftOuterJoin, JoinColumns: &applyJoinColumns{}, JoinPredicates: &applyJoinColumns{}, } + aj.AddJoinPredicate(ctx, predicate) + return aj } // Clone implements the Operator interface @@ -106,7 +104,6 @@ func (aj *ApplyJoin) Clone(inputs []Operator) Operator { kopy.JoinColumns = aj.JoinColumns.clone() kopy.JoinPredicates = aj.JoinPredicates.clone() kopy.Vars = maps.Clone(aj.Vars) - kopy.Predicate = sqlparser.CloneExpr(aj.Predicate) kopy.ExtraLHSVars = slices.Clone(aj.ExtraLHSVars) return &kopy } @@ -150,8 +147,9 @@ func (aj *ApplyJoin) IsInner() bool { } func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) { - aj.Predicate = ctx.SemTable.AndExpressions(expr, aj.Predicate) - + if expr == nil { + return + } col := breakExpressionInLHSandRHSForApplyJoin(ctx, expr, TableID(aj.LHS)) aj.JoinPredicates.add(col) rhs := aj.RHS.AddPredicate(ctx, col.RHSExpr) @@ -266,11 +264,14 @@ func (aj *ApplyJoin) addOffset(offset int) { } func (aj *ApplyJoin) ShortDescription() string { - pred := sqlparser.String(aj.Predicate) - columns := slice.Map(aj.JoinColumns.columns, func(from applyJoinColumn) string { - return sqlparser.String(from.Original) - }) - firstPart := fmt.Sprintf("on %s columns: %s", pred, strings.Join(columns, ", ")) + fn := func(cols *applyJoinColumns) string { + out := slice.Map(cols.columns, func(jc applyJoinColumn) string { + return jc.String() + }) + return strings.Join(out, ", ") + } + + firstPart := fmt.Sprintf("on %s columns: %s", fn(aj.JoinPredicates), fn(aj.JoinColumns)) if len(aj.ExtraLHSVars) == 0 { return firstPart } @@ -361,6 +362,14 @@ func (a *ApplyJoin) LHSColumnsNeeded(ctx *plancontext.PlanningContext) (needed s return ctx.SemTable.Uniquify(needed) } +func (jc applyJoinColumn) String() string { + rhs := sqlparser.String(jc.RHSExpr) + lhs := slice.Map(jc.LHSExprs, func(e BindVarExpr) string { + return sqlparser.String(e.Expr) + }) + return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original)) +} + func (jc applyJoinColumn) IsPureLeft() bool { return jc.RHSExpr == nil } diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 8f6f06132b5..8a46109e959 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -196,6 +196,10 @@ func createOpFromStmt(inCtx *plancontext.PlanningContext, stmt sqlparser.Stateme if err != nil { panic(err) } + + // need to remember which predicates have been broken up during join planning + inCtx.KeepPredicateInfo(ctx) + return op } diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index 35008a3c4ab..1b9194c35fa 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -51,7 +51,8 @@ func breakExpressionInLHSandRHSForApplyJoin( cursor.Replace(arg) }, nil).(sqlparser.Expr) - ctx.JoinPredicates[expr] = append(ctx.JoinPredicates[expr], rewrittenExpr) + ctx.AddJoinPredicates(expr, rewrittenExpr) col.RHSExpr = rewrittenExpr + col.Original = expr return } diff --git a/go/vt/vtgate/planbuilder/operators/join_merging.go b/go/vt/vtgate/planbuilder/operators/join_merging.go index dfd89013e94..0cc5da9121f 100644 --- a/go/vt/vtgate/planbuilder/operators/join_merging.go +++ b/go/vt/vtgate/planbuilder/operators/join_merging.go @@ -203,7 +203,7 @@ func mergeShardedRouting(r1 *ShardedRouting, r2 *ShardedRouting) *ShardedRouting } func (jm *joinMerger) getApplyJoin(ctx *plancontext.PlanningContext, op1, op2 *Route) *ApplyJoin { - return NewApplyJoin(op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), !jm.innerJoin) + return NewApplyJoin(ctx, op1.Source, op2.Source, ctx.SemTable.AndExpressions(jm.predicates...), !jm.innerJoin) } func (jm *joinMerger) merge(ctx *plancontext.PlanningContext, op1, op2 *Route, r Routing) *Route { diff --git a/go/vt/vtgate/planbuilder/operators/offset_planning.go b/go/vt/vtgate/planbuilder/operators/offset_planning.go index 3d74059e812..638d3d80907 100644 --- a/go/vt/vtgate/planbuilder/operators/offset_planning.go +++ b/go/vt/vtgate/planbuilder/operators/offset_planning.go @@ -28,6 +28,7 @@ import ( // planOffsets will walk the tree top down, adding offset information to columns in the tree for use in further optimization, func planOffsets(ctx *plancontext.PlanningContext, root Operator) Operator { type offsettable interface { + Operator planOffsets(ctx *plancontext.PlanningContext) Operator } @@ -37,9 +38,16 @@ func planOffsets(ctx *plancontext.PlanningContext, root Operator) Operator { panic(vterrors.VT13001(fmt.Sprintf("should not see %T here", in))) case offsettable: newOp := op.planOffsets(ctx) - if newOp != nil { - return newOp, Rewrote("new operator after offset planning") + + if newOp == nil { + newOp = op + } + + if DebugOperatorTree { + fmt.Println("Planned offsets for:") + fmt.Println(ToTree(newOp)) } + return newOp, nil } return in, NoRewrite } diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index c06038f893c..1eae4e0e06e 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -142,7 +142,7 @@ func (sp StarProjections) GetSelectExprs() sqlparser.SelectExprs { func (ap AliasedProjections) GetColumns() []*sqlparser.AliasedExpr { return slice.Map(ap, func(from *ProjExpr) *sqlparser.AliasedExpr { - return aeWrap(from.ColExpr) + return from.Original }) } @@ -229,6 +229,14 @@ func (p *Projection) isDerived() bool { return p.DT != nil } +func (p *Projection) derivedName() string { + if p.DT == nil { + return "" + } + + return p.DT.Alias +} + func (p *Projection) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { ap, err := p.GetAliasedProjections() if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/projection_pushing.go b/go/vt/vtgate/planbuilder/operators/projection_pushing.go new file mode 100644 index 00000000000..59f6e6d484d --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/projection_pushing.go @@ -0,0 +1,451 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "slices" + "strconv" + + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/test/dbg" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +type ( + projector struct { + columns []*ProjExpr + columnAliases []string + explicitColumnAliases bool + tableName sqlparser.TableName + } +) + +// add introduces a new projection with the specified alias to the projector. +func (p *projector) add(pe *ProjExpr, alias string) { + p.columns = append(p.columns, pe) + if alias != "" && slices.Index(p.columnAliases, alias) > -1 { + panic("alias already used") + } + p.columnAliases = append(p.columnAliases, alias) +} + +// get finds or adds an expression in the projector, returning its SQL representation with the appropriate alias +func (p *projector) get(ctx *plancontext.PlanningContext, expr sqlparser.Expr) sqlparser.Expr { + for _, column := range p.columns { + if ctx.SemTable.EqualsExprWithDeps(expr, column.ColExpr) { + alias := p.claimUnusedAlias(column.Original) + out := sqlparser.NewColName(alias) + out.Qualifier = p.tableName + + ctx.SemTable.CopySemanticInfo(expr, out) + return out + } + } + + // we could not find the expression, so we add it + alias := sqlparser.UnescapedString(expr) + pe := newProjExpr(sqlparser.NewAliasedExpr(expr, alias)) + p.columns = append(p.columns, pe) + p.columnAliases = append(p.columnAliases, alias) + + out := sqlparser.NewColName(alias) + out.Qualifier = p.tableName + + ctx.SemTable.CopySemanticInfo(expr, out) + + return out +} + +// claimUnusedAlias generates a unique alias based on the provided expression, ensuring no duplication in the projector +func (p *projector) claimUnusedAlias(ae *sqlparser.AliasedExpr) string { + bare := ae.ColumnName() + alias := bare + for i := int64(0); slices.Index(p.columnAliases, alias) > -1; i++ { + alias = bare + strconv.FormatInt(i, 10) + } + return alias +} + +// tryPushProjection attempts to optimize a projection by pushing it down in the query plan +func tryPushProjection( + ctx *plancontext.PlanningContext, + p *Projection, +) (Operator, *ApplyResult) { + switch src := p.Source.(type) { + case *Route: + return Swap(p, src, "push projection under route") + case *Limit: + return Swap(p, src, "push projection under limit") + case *ApplyJoin: + if p.FromAggr || !p.canPush(ctx) { + return p, NoRewrite + } + return pushProjectionInApplyJoin(ctx, p, src) + case *HashJoin: + if !p.canPush(ctx) { + return p, NoRewrite + } + return pushProjectionThroughHashJoin(ctx, p, src) + case *Vindex: + if !p.canPush(ctx) { + return p, NoRewrite + } + return pushProjectionInVindex(ctx, p, src) + case *SubQueryContainer: + if !p.canPush(ctx) { + return p, NoRewrite + } + return pushProjectionToOuterContainer(ctx, p, src) + case *SubQuery: + return pushProjectionToOuter(ctx, p, src) + default: + return p, NoRewrite + } +} + +// pushProjectionThroughHashJoin optimizes projection operations within a hash join +func pushProjectionThroughHashJoin(ctx *plancontext.PlanningContext, p *Projection, hj *HashJoin) (Operator, *ApplyResult) { + cols := p.Columns.(AliasedProjections) + for _, col := range cols { + if !col.isSameInAndOut(ctx) { + return p, NoRewrite + } + hj.columns.add(col.ColExpr) + } + return hj, Rewrote("merged projection into hash join") +} + +func pushProjectionToOuter(ctx *plancontext.PlanningContext, p *Projection, sq *SubQuery) (Operator, *ApplyResult) { + ap, err := p.GetAliasedProjections() + if err != nil { + return p, NoRewrite + } + + if !reachedPhase(ctx, subquerySettling) { + return p, NoRewrite + } + + outer := TableID(sq.Outer) + for _, pe := range ap { + _, isOffset := pe.Info.(*Offset) + if isOffset { + continue + } + + if !ctx.SemTable.RecursiveDeps(pe.EvalExpr).IsSolvedBy(outer) { + return p, NoRewrite + } + + se, ok := pe.Info.(SubQueryExpression) + if ok { + pe.EvalExpr = rewriteColNameToArgument(ctx, pe.EvalExpr, se, sq) + } + } + // all projections can be pushed to the outer + sq.Outer, p.Source = p, sq.Outer + return sq, Rewrote("push projection into outer side of subquery") +} + +func pushProjectionInVindex( + ctx *plancontext.PlanningContext, + p *Projection, + src *Vindex, +) (Operator, *ApplyResult) { + ap, err := p.GetAliasedProjections() + if err != nil { + panic(err) + } + for _, pe := range ap { + src.AddColumn(ctx, true, false, aeWrap(pe.EvalExpr)) + } + return src, Rewrote("push projection into vindex") +} + +func pushProjectionToOuterContainer(ctx *plancontext.PlanningContext, p *Projection, src *SubQueryContainer) (Operator, *ApplyResult) { + ap, err := p.GetAliasedProjections() + if err != nil { + return p, NoRewrite + } + + outer := TableID(src.Outer) + for _, pe := range ap { + _, isOffset := pe.Info.(*Offset) + if isOffset { + continue + } + + if !ctx.SemTable.RecursiveDeps(pe.EvalExpr).IsSolvedBy(outer) { + return p, NoRewrite + } + + if se, ok := pe.Info.(SubQueryExpression); ok { + pe.EvalExpr = rewriteColNameToArgument(ctx, pe.EvalExpr, se, src.Inner...) + } + } + // all projections can be pushed to the outer + src.Outer, p.Source = p, src.Outer + return src, Rewrote("push projection into outer side of subquery container") +} + +// pushProjectionInApplyJoin pushes down a projection operation into an ApplyJoin operation. +// It processes each input column and creates new JoinPredicates for the ApplyJoin operation based on +// the input column's expression. It also creates new Projection operators for the left and right +// children of the ApplyJoin operation, if needed. +func pushProjectionInApplyJoin( + ctx *plancontext.PlanningContext, + p *Projection, + src *ApplyJoin, +) (Operator, *ApplyResult) { + ap, err := p.GetAliasedProjections() + if src.LeftJoin || err != nil { + // we can't push down expression evaluation to the rhs if we are not sure if it will even be executed + return p, NoRewrite + } + lhs, rhs := &projector{}, &projector{} + if p.DT != nil && len(p.DT.Columns) > 0 { + lhs.explicitColumnAliases = true + rhs.explicitColumnAliases = true + } + + src.JoinColumns = &applyJoinColumns{} + for idx, pe := range ap { + var alias string + if p.DT != nil && len(p.DT.Columns) > 0 { + if len(p.DT.Columns) <= idx { + panic(vterrors.VT13001("no such alias found for derived table")) + } + alias = p.DT.Columns[idx].String() + } + splitProjectionAcrossJoin(ctx, src, lhs, rhs, pe, alias) + } + + if p.isDerived() { + exposeColumnsThroughDerivedTable(ctx, p, src, lhs, rhs) + } + + // Create and update the Projection operators for the left and right children, if needed. + src.LHS = createProjectionWithTheseColumns(ctx, src.LHS, lhs, p.DT) + src.RHS = createProjectionWithTheseColumns(ctx, src.RHS, rhs, p.DT) + + return src, Rewrote("split projection to either side of join") +} + +// splitProjectionAcrossJoin creates JoinPredicates for all projections, +// and pushes down columns as needed between the LHS and RHS of a join +func splitProjectionAcrossJoin( + ctx *plancontext.PlanningContext, + join *ApplyJoin, + lhs, rhs *projector, + pe *ProjExpr, + colAlias string, +) { + + // Check if the current expression can reuse an existing column in the ApplyJoin. + if _, found := canReuseColumn(ctx, join.JoinColumns.columns, pe.EvalExpr, joinColumnToExpr); found { + return + } + + switch pe.Info.(type) { + case nil: + join.JoinColumns.add(splitUnexploredExpression(ctx, join, lhs, rhs, pe, colAlias)) + case Offset: + // for offsets, we'll just treat the expression as unexplored, and later stages will handle the new offset + join.JoinColumns.add(splitUnexploredExpression(ctx, join, lhs, rhs, pe, colAlias)) + case SubQueryExpression: + join.JoinColumns.add(splitSubqueryExpression(ctx, join, lhs, rhs, pe, colAlias)) + default: + panic(dbg.S(pe.Info)) + } +} + +func splitSubqueryExpression( + ctx *plancontext.PlanningContext, + join *ApplyJoin, + lhs, rhs *projector, + pe *ProjExpr, + alias string, +) applyJoinColumn { + col := join.getJoinColumnFor(ctx, pe.Original, pe.ColExpr, false) + return pushDownSplitJoinCol(col, lhs, pe, alias, rhs) +} + +func splitUnexploredExpression( + ctx *plancontext.PlanningContext, + join *ApplyJoin, + lhs, rhs *projector, + pe *ProjExpr, + alias string, +) applyJoinColumn { + // Get a applyJoinColumn for the current expression. + col := join.getJoinColumnFor(ctx, pe.Original, pe.ColExpr, false) + + return pushDownSplitJoinCol(col, lhs, pe, alias, rhs) +} + +func pushDownSplitJoinCol(col applyJoinColumn, lhs *projector, pe *ProjExpr, alias string, rhs *projector) applyJoinColumn { + // Update the left and right child columns and names based on the applyJoinColumn type. + switch { + case col.IsPureLeft(): + lhs.add(pe, alias) + case col.IsPureRight(): + rhs.add(pe, alias) + case col.IsMixedLeftAndRight(): + for _, lhsExpr := range col.LHSExprs { + var lhsAlias string + if alias != "" { + // we need to add an explicit column alias here. let's try just the ColName as is first + lhsAlias = sqlparser.String(lhsExpr.Expr) + } + lhs.add(newProjExpr(aeWrap(lhsExpr.Expr)), lhsAlias) + } + innerPE := newProjExprWithInner(pe.Original, col.RHSExpr) + innerPE.ColExpr = col.RHSExpr + innerPE.Info = pe.Info + rhs.add(innerPE, alias) + } + return col +} + +// exposeColumnsThroughDerivedTable rewrites expressions within a join that is inside a derived table +// in order to make them accessible outside the derived table. This is necessary when swapping the +// positions of the derived table and join operation. +// +// For example, consider the input query: +// select ... from (select T1.foo from T1 join T2 on T1.id = T2.id) as t +// If we push the derived table under the join, with T1 on the LHS of the join, we need to expose +// the values of T1.id through the derived table, or they will not be accessible on the RHS. +// +// The function iterates through each join predicate, rewriting the expressions in the predicate's +// LHS expressions to include the derived table. This allows the expressions to be accessed outside +// the derived table. +func exposeColumnsThroughDerivedTable(ctx *plancontext.PlanningContext, p *Projection, src *ApplyJoin, lhs, rhs *projector) { + derivedTbl, err := ctx.SemTable.TableInfoFor(p.DT.TableID) + if err != nil { + panic(err) + } + derivedTblName, err := derivedTbl.Name() + if err != nil { + panic(err) + } + lhs.tableName = derivedTblName + rhs.tableName = derivedTblName + + lhsIDs := TableID(src.LHS) + rhsIDs := TableID(src.RHS) + rewriteColumnsForJoin(ctx, src.JoinPredicates.columns, lhsIDs, rhsIDs, lhs, rhs, false) + rewriteColumnsForJoin(ctx, src.JoinColumns.columns, lhsIDs, rhsIDs, lhs, rhs, true) +} + +func rewriteColumnsForJoin( + ctx *plancontext.PlanningContext, + columns []applyJoinColumn, + lhsIDs, rhsIDs semantics.TableSet, + lhs, rhs *projector, + exposeRHS bool, // we only want to expose the returned columns from the RHS. + // For predicates, we don't need to expose the RHS columns +) { + for colIdx, column := range columns { + for lhsIdx, bve := range column.LHSExprs { + // since this is on the LHSExprs, we know that dependencies are from that side of the join + column.LHSExprs[lhsIdx].Expr = lhs.get(ctx, bve.Expr) + } + if column.IsPureLeft() { + continue + } + + // now we need to go over the predicate and find + var rewriteTo sqlparser.Expr + + pre := func(node, _ sqlparser.SQLNode) bool { + _, isSQ := node.(*sqlparser.Subquery) + if isSQ { + return false + } + expr, ok := node.(sqlparser.Expr) + if !ok { + return true + } + deps := ctx.SemTable.RecursiveDeps(expr) + + switch { + case deps.IsEmpty(): + return true + case deps.IsSolvedBy(lhsIDs): + rewriteTo = lhs.get(ctx, expr) + return false + case deps.IsSolvedBy(rhsIDs): + if exposeRHS { + rewriteTo = rhs.get(ctx, expr) + } + return false + default: + return true + } + } + + post := func(cursor *sqlparser.CopyOnWriteCursor) { + if rewriteTo != nil { + cursor.Replace(rewriteTo) + rewriteTo = nil + return + } + } + newOriginal := sqlparser.CopyOnRewrite(column.Original, pre, post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) + column.Original = newOriginal + + columns[colIdx] = column + } +} + +// prefixColNames adds qualifier prefixes to all ColName:s. +// We want to be more explicit than the user was to make sure we never produce invalid SQL +func prefixColNames(ctx *plancontext.PlanningContext, tblName sqlparser.TableName, e sqlparser.Expr) sqlparser.Expr { + return sqlparser.CopyOnRewrite(e, nil, func(cursor *sqlparser.CopyOnWriteCursor) { + col, ok := cursor.Node().(*sqlparser.ColName) + if !ok { + return + } + cursor.Replace(sqlparser.NewColNameWithQualifier(col.Name.String(), tblName)) + }, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) +} + +func createProjectionWithTheseColumns( + ctx *plancontext.PlanningContext, + src Operator, + p *projector, + dt *DerivedTable, +) Operator { + if len(p.columns) == 0 { + return src + } + proj := createProjection(ctx, src, "") + proj.Columns = AliasedProjections(p.columns) + if dt != nil { + kopy := *dt + if p.explicitColumnAliases { + kopy.Columns = slice.Map(p.columnAliases, func(s string) sqlparser.IdentifierCI { + return sqlparser.NewIdentifierCI(s) + }) + } + proj.DT = &kopy + } + + return proj +} diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 135126a15bd..f412e783f42 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -27,14 +27,6 @@ import ( "vitess.io/vitess/go/vt/vtgate/semantics" ) -type ( - projector struct { - columns []*ProjExpr - columnAliases sqlparser.Columns - explicitColumnAliases bool - } -) - func planQuery(ctx *plancontext.PlanningContext, root Operator) Operator { output := runPhases(ctx, root) output = planOffsets(ctx, output) @@ -251,280 +243,6 @@ func pushOrExpandHorizon(ctx *plancontext.PlanningContext, in *Horizon) (Operato return expandHorizon(ctx, in) } -func tryPushProjection( - ctx *plancontext.PlanningContext, - p *Projection, -) (Operator, *ApplyResult) { - switch src := p.Source.(type) { - case *Route: - return Swap(p, src, "push projection under route") - case *ApplyJoin: - if p.FromAggr || !p.canPush(ctx) { - return p, NoRewrite - } - return pushProjectionInApplyJoin(ctx, p, src) - case *HashJoin: - if !p.canPush(ctx) { - return p, NoRewrite - } - return pushProjectionThroughHashJoin(ctx, p, src) - case *Vindex: - if !p.canPush(ctx) { - return p, NoRewrite - } - return pushProjectionInVindex(ctx, p, src) - case *SubQueryContainer: - if !p.canPush(ctx) { - return p, NoRewrite - } - return pushProjectionToOuterContainer(ctx, p, src) - case *SubQuery: - return pushProjectionToOuter(ctx, p, src) - case *Limit: - return Swap(p, src, "push projection under limit") - default: - return p, NoRewrite - } -} - -func pushProjectionThroughHashJoin(ctx *plancontext.PlanningContext, p *Projection, hj *HashJoin) (Operator, *ApplyResult) { - cols := p.Columns.(AliasedProjections) - for _, col := range cols { - if !col.isSameInAndOut(ctx) { - return p, NoRewrite - } - hj.columns.add(col.ColExpr) - } - return hj, Rewrote("merged projection into hash join") -} - -func pushProjectionToOuter(ctx *plancontext.PlanningContext, p *Projection, sq *SubQuery) (Operator, *ApplyResult) { - ap, err := p.GetAliasedProjections() - if err != nil { - return p, NoRewrite - } - - if !reachedPhase(ctx, subquerySettling) { - return p, NoRewrite - } - - outer := TableID(sq.Outer) - for _, pe := range ap { - _, isOffset := pe.Info.(*Offset) - if isOffset { - continue - } - - if !ctx.SemTable.RecursiveDeps(pe.EvalExpr).IsSolvedBy(outer) { - return p, NoRewrite - } - - se, ok := pe.Info.(SubQueryExpression) - if ok { - pe.EvalExpr = rewriteColNameToArgument(ctx, pe.EvalExpr, se, sq) - } - } - // all projections can be pushed to the outer - sq.Outer, p.Source = p, sq.Outer - return sq, Rewrote("push projection into outer side of subquery") -} - -func pushProjectionInVindex( - ctx *plancontext.PlanningContext, - p *Projection, - src *Vindex, -) (Operator, *ApplyResult) { - ap, err := p.GetAliasedProjections() - if err != nil { - panic(err) - } - for _, pe := range ap { - src.AddColumn(ctx, true, false, aeWrap(pe.EvalExpr)) - } - return src, Rewrote("push projection into vindex") -} - -func (p *projector) add(pe *ProjExpr, col *sqlparser.IdentifierCI) { - p.columns = append(p.columns, pe) - if col != nil { - p.columnAliases = append(p.columnAliases, *col) - } -} - -// pushProjectionInApplyJoin pushes down a projection operation into an ApplyJoin operation. -// It processes each input column and creates new JoinPredicates for the ApplyJoin operation based on -// the input column's expression. It also creates new Projection operators for the left and right -// children of the ApplyJoin operation, if needed. -func pushProjectionInApplyJoin( - ctx *plancontext.PlanningContext, - p *Projection, - src *ApplyJoin, -) (Operator, *ApplyResult) { - ap, err := p.GetAliasedProjections() - if src.LeftJoin || err != nil { - // we can't push down expression evaluation to the rhs if we are not sure if it will even be executed - return p, NoRewrite - } - lhs, rhs := &projector{}, &projector{} - if p.DT != nil && len(p.DT.Columns) > 0 { - lhs.explicitColumnAliases = true - rhs.explicitColumnAliases = true - } - - src.JoinColumns = &applyJoinColumns{} - for idx, pe := range ap { - var col *sqlparser.IdentifierCI - if p.DT != nil && idx < len(p.DT.Columns) { - col = &p.DT.Columns[idx] - } - splitProjectionAcrossJoin(ctx, src, lhs, rhs, pe, col) - } - - if p.isDerived() { - exposeColumnsThroughDerivedTable(ctx, p, src, lhs) - } - - // Create and update the Projection operators for the left and right children, if needed. - src.LHS = createProjectionWithTheseColumns(ctx, src.LHS, lhs, p.DT) - src.RHS = createProjectionWithTheseColumns(ctx, src.RHS, rhs, p.DT) - - return src, Rewrote("split projection to either side of join") -} - -// splitProjectionAcrossJoin creates JoinPredicates for all projections, -// and pushes down columns as needed between the LHS and RHS of a join -func splitProjectionAcrossJoin( - ctx *plancontext.PlanningContext, - join *ApplyJoin, - lhs, rhs *projector, - pe *ProjExpr, - colAlias *sqlparser.IdentifierCI, -) { - - // Check if the current expression can reuse an existing column in the ApplyJoin. - if _, found := canReuseColumn(ctx, join.JoinColumns.columns, pe.EvalExpr, joinColumnToExpr); found { - return - } - - // Add the new applyJoinColumn to the ApplyJoin's JoinPredicates. - join.JoinColumns.add(splitUnexploredExpression(ctx, join, lhs, rhs, pe, colAlias)) -} - -func splitUnexploredExpression( - ctx *plancontext.PlanningContext, - join *ApplyJoin, - lhs, rhs *projector, - pe *ProjExpr, - colAlias *sqlparser.IdentifierCI, -) applyJoinColumn { - // Get a applyJoinColumn for the current expression. - col := join.getJoinColumnFor(ctx, pe.Original, pe.ColExpr, false) - - // Update the left and right child columns and names based on the applyJoinColumn type. - switch { - case col.IsPureLeft(): - lhs.add(pe, colAlias) - case col.IsPureRight(): - rhs.add(pe, colAlias) - case col.IsMixedLeftAndRight(): - for _, lhsExpr := range col.LHSExprs { - var lhsAlias *sqlparser.IdentifierCI - if colAlias != nil { - // we need to add an explicit column alias here. let's try just the ColName as is first - ci := sqlparser.NewIdentifierCI(sqlparser.String(lhsExpr.Expr)) - lhsAlias = &ci - } - lhs.add(newProjExpr(aeWrap(lhsExpr.Expr)), lhsAlias) - } - innerPE := newProjExprWithInner(pe.Original, col.RHSExpr) - innerPE.ColExpr = col.RHSExpr - innerPE.Info = pe.Info - rhs.add(innerPE, colAlias) - } - return col -} - -// exposeColumnsThroughDerivedTable rewrites expressions within a join that is inside a derived table -// in order to make them accessible outside the derived table. This is necessary when swapping the -// positions of the derived table and join operation. -// -// For example, consider the input query: -// select ... from (select T1.foo from T1 join T2 on T1.id = T2.id) as t -// If we push the derived table under the join, with T1 on the LHS of the join, we need to expose -// the values of T1.id through the derived table, or they will not be accessible on the RHS. -// -// The function iterates through each join predicate, rewriting the expressions in the predicate's -// LHS expressions to include the derived table. This allows the expressions to be accessed outside -// the derived table. -func exposeColumnsThroughDerivedTable(ctx *plancontext.PlanningContext, p *Projection, src *ApplyJoin, lhs *projector) { - derivedTbl, err := ctx.SemTable.TableInfoFor(p.DT.TableID) - if err != nil { - panic(err) - } - derivedTblName, err := derivedTbl.Name() - if err != nil { - panic(err) - } - for _, predicate := range src.JoinPredicates.columns { - for idx, bve := range predicate.LHSExprs { - expr := bve.Expr - tbl, err := ctx.SemTable.TableInfoForExpr(expr) - if err != nil { - panic(err) - } - tblName, err := tbl.Name() - if err != nil { - panic(err) - } - - expr = semantics.RewriteDerivedTableExpression(expr, derivedTbl) - out := prefixColNames(ctx, tblName, expr) - - alias := sqlparser.UnescapedString(out) - predicate.LHSExprs[idx].Expr = sqlparser.NewColNameWithQualifier(alias, derivedTblName) - identifierCI := sqlparser.NewIdentifierCI(alias) - projExpr := newProjExprWithInner(&sqlparser.AliasedExpr{Expr: out, As: identifierCI}, out) - var colAlias *sqlparser.IdentifierCI - if lhs.explicitColumnAliases { - colAlias = &identifierCI - } - lhs.add(projExpr, colAlias) - } - } -} - -// prefixColNames adds qualifier prefixes to all ColName:s. -// We want to be more explicit than the user was to make sure we never produce invalid SQL -func prefixColNames(ctx *plancontext.PlanningContext, tblName sqlparser.TableName, e sqlparser.Expr) sqlparser.Expr { - return sqlparser.CopyOnRewrite(e, nil, func(cursor *sqlparser.CopyOnWriteCursor) { - col, ok := cursor.Node().(*sqlparser.ColName) - if !ok { - return - } - cursor.Replace(sqlparser.NewColNameWithQualifier(col.Name.String(), tblName)) - }, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) -} - -func createProjectionWithTheseColumns( - ctx *plancontext.PlanningContext, - src Operator, - p *projector, - dt *DerivedTable, -) Operator { - if len(p.columns) == 0 { - return src - } - proj := createProjection(ctx, src) - proj.Columns = AliasedProjections(p.columns) - if dt != nil { - kopy := *dt - kopy.Columns = p.columnAliases - proj.DT = &kopy - } - - return proj -} - func tryPushLimit(in *Limit) (Operator, *ApplyResult) { switch src := in.Source.(type) { case *Route: diff --git a/go/vt/vtgate/planbuilder/operators/route.go b/go/vt/vtgate/planbuilder/operators/route.go index 5303cac401b..ea50f605105 100644 --- a/go/vt/vtgate/planbuilder/operators/route.go +++ b/go/vt/vtgate/planbuilder/operators/route.go @@ -19,13 +19,12 @@ package operators import ( "fmt" - topodatapb "vitess.io/vitess/go/vt/proto/topodata" - "vitess.io/vitess/go/vt/vtenv" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/key" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -564,11 +563,21 @@ func (r *Route) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Ex return r } -func createProjection(ctx *plancontext.PlanningContext, src Operator) *Projection { +func createProjection(ctx *plancontext.PlanningContext, src Operator, derivedName string) *Projection { proj := newAliasedProjection(src) cols := src.GetColumns(ctx) for _, col := range cols { - proj.addUnexploredExpr(col, col.Expr) + if derivedName == "" { + proj.addUnexploredExpr(col, col.Expr) + continue + } + + // for derived tables, we want to use the exposed colname + tableName := sqlparser.NewTableName(derivedName) + columnName := col.ColumnName() + colName := sqlparser.NewColNameWithQualifier(columnName, tableName) + ctx.SemTable.CopySemanticInfo(col.Expr, colName) + proj.addUnexploredExpr(aeWrap(colName), colName) } return proj } @@ -585,14 +594,14 @@ func (r *Route) AddColumn(ctx *plancontext.PlanningContext, reuse bool, gb bool, // if at least one column is not already present, we check if we can easily find a projection // or aggregation in our source that we can add to - op, ok, offsets := addMultipleColumnsToInput(ctx, r.Source, reuse, []bool{gb}, []*sqlparser.AliasedExpr{expr}) + derived, op, ok, offsets := addMultipleColumnsToInput(ctx, r.Source, reuse, []bool{gb}, []*sqlparser.AliasedExpr{expr}) r.Source = op if ok { return offsets[0] } // If no-one could be found, we probably don't have one yet, so we add one here - src := createProjection(ctx, r.Source) + src := createProjection(ctx, r.Source, derived) r.Source = src offsets = src.addColumnsWithoutPushing(ctx, reuse, []bool{gb}, []*sqlparser.AliasedExpr{expr}) @@ -603,57 +612,66 @@ type selectExpressions interface { Operator addColumnWithoutPushing(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int addColumnsWithoutPushing(ctx *plancontext.PlanningContext, reuse bool, addToGroupBy []bool, exprs []*sqlparser.AliasedExpr) []int - isDerived() bool + derivedName() string } -// addColumnToInput adds a column to an operator without pushing it down. -// It will return a bool indicating whether the addition was successful or not, -// and an offset to where the column can be found -func addMultipleColumnsToInput(ctx *plancontext.PlanningContext, operator Operator, reuse bool, addToGroupBy []bool, exprs []*sqlparser.AliasedExpr) (Operator, bool, []int) { +// addColumnToInput adds columns to an operator without pushing them down +func addMultipleColumnsToInput( + ctx *plancontext.PlanningContext, + operator Operator, + reuse bool, + addToGroupBy []bool, + exprs []*sqlparser.AliasedExpr, +) (derivedName string, // if we found a derived table, this will contain its name + projection Operator, // if an operator needed to be built, it will be returned here + found bool, // whether a matching op was found or not + offsets []int, // the offsets the expressions received +) { switch op := operator.(type) { case *SubQuery: - src, added, offset := addMultipleColumnsToInput(ctx, op.Outer, reuse, addToGroupBy, exprs) + derivedName, src, added, offset := addMultipleColumnsToInput(ctx, op.Outer, reuse, addToGroupBy, exprs) if added { op.Outer = src } - return op, added, offset + return derivedName, op, added, offset case *Distinct: - src, added, offset := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) + derivedName, src, added, offset := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) if added { op.Source = src } - return op, added, offset + return derivedName, op, added, offset case *Limit: - src, added, offset := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) + derivedName, src, added, offset := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) if added { op.Source = src } - return op, added, offset + return derivedName, op, added, offset case *Ordering: - src, added, offset := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) + derivedName, src, added, offset := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) if added { op.Source = src } - return op, added, offset + return derivedName, op, added, offset case *LockAndComment: - src, added, offset := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) + derivedName, src, added, offset := addMultipleColumnsToInput(ctx, op.Source, reuse, addToGroupBy, exprs) if added { op.Source = src } - return op, added, offset + return derivedName, op, added, offset case selectExpressions: - if op.isDerived() { + name := op.derivedName() + if name != "" { // if the only thing we can push to is a derived table, // we have to add a new projection and can't build on this one - return op, false, nil + return name, op, false, nil } offset := op.addColumnsWithoutPushing(ctx, reuse, addToGroupBy, exprs) - return op, true, offset + return "", op, true, offset case *Union: tableID := semantics.SingleTableSet(len(ctx.SemTable.Tables)) @@ -669,7 +687,7 @@ func addMultipleColumnsToInput(ctx *plancontext.PlanningContext, operator Operat } return addMultipleColumnsToInput(ctx, proj, reuse, addToGroupBy, exprs) default: - return op, false, nil + return "", op, false, nil } } diff --git a/go/vt/vtgate/planbuilder/operators/route_planning.go b/go/vt/vtgate/planbuilder/operators/route_planning.go index 81fece1e596..6290ec9038c 100644 --- a/go/vt/vtgate/planbuilder/operators/route_planning.go +++ b/go/vt/vtgate/planbuilder/operators/route_planning.go @@ -347,12 +347,12 @@ func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredic return join, Rewrote("use a hash join because we have LIMIT on the LHS") } - join := NewApplyJoin(Clone(rhs), Clone(lhs), nil, !inner) + join := NewApplyJoin(ctx, Clone(rhs), Clone(lhs), nil, !inner) newOp := pushJoinPredicates(ctx, joinPredicates, join) return newOp, Rewrote("logical join to applyJoin, switching side because LIMIT") } - join := NewApplyJoin(Clone(lhs), Clone(rhs), nil, !inner) + join := NewApplyJoin(ctx, Clone(lhs), Clone(rhs), nil, !inner) newOp := pushJoinPredicates(ctx, joinPredicates, join) return newOp, Rewrote("logical join to applyJoin ") } diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index 0765a878a3e..24417cfab21 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -167,7 +167,7 @@ func (sq *SubQuery) ShortDescription() string { preds := append(sq.Predicates, sq.OuterPredicate) pred = " MERGE ON " + sqlparser.String(sqlparser.AndExpressions(preds...)) } - return fmt.Sprintf("%s %v%s", typ, sq.FilterType.String(), pred) + return fmt.Sprintf(":%s %s %v%s", sq.ArgName, typ, sq.FilterType.String(), pred) } func (sq *SubQuery) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 0980cca9cc8..1727f7bedcb 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -343,32 +343,6 @@ func rewriteOriginalPushedToRHS(ctx *plancontext.PlanningContext, expression sql return result.(sqlparser.Expr) } -func pushProjectionToOuterContainer(ctx *plancontext.PlanningContext, p *Projection, src *SubQueryContainer) (Operator, *ApplyResult) { - ap, err := p.GetAliasedProjections() - if err != nil { - return p, NoRewrite - } - - outer := TableID(src.Outer) - for _, pe := range ap { - _, isOffset := pe.Info.(*Offset) - if isOffset { - continue - } - - if !ctx.SemTable.RecursiveDeps(pe.EvalExpr).IsSolvedBy(outer) { - return p, NoRewrite - } - - if se, ok := pe.Info.(SubQueryExpression); ok { - pe.EvalExpr = rewriteColNameToArgument(ctx, pe.EvalExpr, se, src.Inner...) - } - } - // all projections can be pushed to the outer - src.Outer, p.Source = p, src.Outer - return src, Rewrote("push projection into outer side of subquery container") -} - func rewriteColNameToArgument(ctx *plancontext.PlanningContext, in sqlparser.Expr, se SubQueryExpression, subqueries ...*SubQuery) sqlparser.Expr { rewriteIt := func(s string) sqlparser.SQLNode { for _, sq1 := range se { diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index 1a09aa555ec..caef2ae21f4 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -289,6 +289,9 @@ func TestOne(t *testing.T) { } func TestOneTPCC(t *testing.T) { + reset := operators.EnableDebugPrinting() + defer reset() + vschema := &vschemawrapper.VSchemaWrapper{ V: loadSchema(t, "vschemas/tpcc_schema.json", true), Env: vtenv.NewTestEnv(), @@ -298,6 +301,8 @@ func TestOneTPCC(t *testing.T) { } func TestOneWithMainAsDefault(t *testing.T) { + reset := operators.EnableDebugPrinting() + defer reset() vschema := &vschemawrapper.VSchemaWrapper{ V: loadSchema(t, "vschemas/schema.json", true), Keyspace: &vindexes.Keyspace{ @@ -311,6 +316,8 @@ func TestOneWithMainAsDefault(t *testing.T) { } func TestOneWithSecondUserAsDefault(t *testing.T) { + reset := operators.EnableDebugPrinting() + defer reset() vschema := &vschemawrapper.VSchemaWrapper{ V: loadSchema(t, "vschemas/schema.json", true), Keyspace: &vindexes.Keyspace{ @@ -324,6 +331,8 @@ func TestOneWithSecondUserAsDefault(t *testing.T) { } func TestOneWithUserAsDefault(t *testing.T) { + reset := operators.EnableDebugPrinting() + defer reset() vschema := &vschemawrapper.VSchemaWrapper{ V: loadSchema(t, "vschemas/schema.json", true), Keyspace: &vindexes.Keyspace{ @@ -349,6 +358,8 @@ func TestOneWithTPCHVSchema(t *testing.T) { } func TestOneWith57Version(t *testing.T) { + reset := operators.EnableDebugPrinting() + defer reset() // first we move everything to use 5.7 logic env, err := vtenv.New(vtenv.Options{ MySQLServerVersion: "5.7.9", diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 3871c8fdbc4..49039ddd347 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -19,6 +19,7 @@ package plancontext import ( querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -27,12 +28,16 @@ type PlanningContext struct { SemTable *semantics.SemTable VSchema VSchema - // here we add all predicates that were created because of a join condition - // e.g. [FROM tblA JOIN tblB ON a.colA = b.colB] will be rewritten to [FROM tblB WHERE :a_colA = b.colB], - // if we assume that tblB is on the RHS of the join. This last predicate in the WHERE clause is added to the - // map below - JoinPredicates map[sqlparser.Expr][]sqlparser.Expr - SkipPredicates map[sqlparser.Expr]any + // joinPredicates maps each original join predicate (key) to a slice of + // variations of the RHS predicates (value). This map is used to handle + // different scenarios in join planning, where the RHS predicates are + // modified to accommodate dependencies from the LHS, represented as Arguments. + joinPredicates map[sqlparser.Expr][]sqlparser.Expr + + // skipPredicates tracks predicates that should be skipped, typically when + // a join predicate is reverted to its original form during planning. + skipPredicates map[sqlparser.Expr]any + PlannerVersion querypb.ExecuteOptions_PlannerVersion // If we during planning have turned this expression into an argument name, @@ -54,6 +59,10 @@ type PlanningContext struct { Statement sqlparser.Statement } +// CreatePlanningContext initializes a new PlanningContext with the given parameters. +// It analyzes the SQL statement within the given virtual schema context, +// handling default keyspace settings and semantic analysis. +// Returns an error if semantic analysis fails. func CreatePlanningContext(stmt sqlparser.Statement, reservedVars *sqlparser.ReservedVars, vschema VSchema, @@ -76,14 +85,17 @@ func CreatePlanningContext(stmt sqlparser.Statement, ReservedVars: reservedVars, SemTable: semTable, VSchema: vschema, - JoinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, - SkipPredicates: map[sqlparser.Expr]any{}, + joinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, + skipPredicates: map[sqlparser.Expr]any{}, PlannerVersion: version, ReservedArguments: map[sqlparser.Expr]string{}, Statement: stmt, }, nil } +// GetReservedArgumentFor retrieves a reserved argument name for a given expression. +// If the expression already has a reserved argument, it returns that name; +// otherwise, it reserves a new name based on the expression type. func (ctx *PlanningContext) GetReservedArgumentFor(expr sqlparser.Expr) string { for key, name := range ctx.ReservedArguments { if ctx.SemTable.EqualsExpr(key, expr) { @@ -104,13 +116,75 @@ func (ctx *PlanningContext) GetReservedArgumentFor(expr sqlparser.Expr) string { return bvName } -func (ctx *PlanningContext) GetArgumentFor(expr sqlparser.Expr, f func() string) string { - for key, name := range ctx.ReservedArguments { - if ctx.SemTable.EqualsExpr(key, expr) { - return name +// ShouldSkip determines if a given expression should be ignored in the SQL output building. +// It checks against expressions that have been marked to be excluded from further processing. +func (ctx *PlanningContext) ShouldSkip(expr sqlparser.Expr) bool { + for k := range ctx.skipPredicates { + if ctx.SemTable.EqualsExpr(expr, k) { + return true } } - bvName := f() - ctx.ReservedArguments[expr] = bvName - return bvName + return false +} + +// AddJoinPredicates associates additional RHS predicates with an existing join predicate. +// This is used to dynamically adjust the RHS predicates based on evolving join conditions. +func (ctx *PlanningContext) AddJoinPredicates(joinPred sqlparser.Expr, predicates ...sqlparser.Expr) { + fn := func(original sqlparser.Expr, rhsExprs []sqlparser.Expr) { + ctx.joinPredicates[original] = append(rhsExprs, predicates...) + } + if ctx.execOnJoinPredicateEqual(joinPred, fn) { + return + } + + // we didn't find an existing entry + ctx.joinPredicates[joinPred] = predicates +} + +// SkipJoinPredicates marks the predicates related to a specific join predicate as irrelevant +// for the current planning stage. This is used when a join has been pushed under a route and +// the original predicate will be used. +func (ctx *PlanningContext) SkipJoinPredicates(joinPred sqlparser.Expr) error { + fn := func(_ sqlparser.Expr, rhsExprs []sqlparser.Expr) { + ctx.skipThesePredicates(rhsExprs...) + } + if ctx.execOnJoinPredicateEqual(joinPred, fn) { + return nil + } + return vterrors.VT13001("predicate does not exist: " + sqlparser.String(joinPred)) +} + +// KeepPredicateInfo transfers join predicate information from another context. +// This is useful when nesting queries, ensuring consistent predicate handling across contexts. +func (ctx *PlanningContext) KeepPredicateInfo(other *PlanningContext) { + for k, v := range other.joinPredicates { + ctx.AddJoinPredicates(k, v...) + } + for expr := range other.skipPredicates { + ctx.skipThesePredicates(expr) + } +} + +// skipThesePredicates is a utility function to exclude certain predicates from SQL building +func (ctx *PlanningContext) skipThesePredicates(preds ...sqlparser.Expr) { +outer: + for _, expr := range preds { + for k := range ctx.skipPredicates { + if ctx.SemTable.EqualsExpr(expr, k) { + // already skipped + continue outer + } + } + ctx.skipPredicates[expr] = nil + } +} + +func (ctx *PlanningContext) execOnJoinPredicateEqual(joinPred sqlparser.Expr, fn func(original sqlparser.Expr, rhsExprs []sqlparser.Expr)) bool { + for key, values := range ctx.joinPredicates { + if ctx.SemTable.EqualsExpr(joinPred, key) { + fn(key, values) + return true + } + } + return false } diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index 4a1c8fa1559..92722113693 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -3405,8 +3405,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select phone, id, city from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", - "Query": "select phone, id, city from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", + "FieldQuery": "select x.phone, x.id, x.city from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", + "Query": "select x.phone, x.id, x.city from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", "Table": "`user`" } ] @@ -3448,8 +3448,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select phone, id, city, 1 from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", - "Query": "select phone, id, city, 1 from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", + "FieldQuery": "select x.phone, x.id, x.city, 1 from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", + "Query": "select x.phone, x.id, x.city, 1 from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", "Table": "`user`" } ] @@ -3553,9 +3553,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, val1, 1, weight_string(val1) from (select id, val1 from `user` where 1 != 1) as x where 1 != 1", + "FieldQuery": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where 1 != 1) as x where 1 != 1", "OrderBy": "(1|3) ASC", - "Query": "select id, val1, 1, weight_string(val1) from (select id, val1 from `user` where val2 < 4) as x order by val1 asc limit :__upper_limit", + "Query": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where val2 < 4) as x order by val1 asc limit :__upper_limit", "Table": "`user`" } ] @@ -5856,8 +5856,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, val2 from (select id, val2 from `user` where 1 != 1) as x where 1 != 1", - "Query": "select id, val2 from (select id, val2 from `user` where val2 is null) as x limit :__upper_limit", + "FieldQuery": "select x.id, x.val2 from (select id, val2 from `user` where 1 != 1) as x where 1 != 1", + "Query": "select x.id, x.val2 from (select id, val2 from `user` where val2 is null) as x limit :__upper_limit", "Table": "`user`" } ] @@ -6467,74 +6467,59 @@ "OrderBy": "(4|6) ASC, (5|7) ASC", "Inputs": [ { - "OperatorType": "Projection", - "Expressions": [ - "count(*) as count(*)", - "count(*) as count(*)", - "`user`.col as col", - "ue.col as col", - "`user`.foo as foo", - "ue.bar as bar", - "weight_string(`user`.foo) as weight_string(`user`.foo)", - "weight_string(ue.bar) as weight_string(ue.bar)" - ], + "OperatorType": "Join", + "Variant": "HashLeftJoin", + "Collation": "binary", + "ComparisonType": "INT16", + "JoinColumnIndexes": "-1,1,-2,2,-3,3", + "Predicate": "`user`.col = ue.col", + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "HashLeftJoin", - "Collation": "binary", - "ComparisonType": "INT16", - "JoinColumnIndexes": "-1,1,-2,2,-3,3", - "Predicate": "`user`.col = ue.col", - "TableName": "`user`_user_extra", + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(*), `user`.col, `user`.foo from `user` where 1 != 1 group by `user`.col, `user`.foo", + "Query": "select count(*), `user`.col, `user`.foo from `user` group by `user`.col, `user`.foo", + "Table": "`user`" + }, + { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count_star(0)", + "GroupBy": "1, (2|3)", "Inputs": [ { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select count(*), `user`.col, `user`.foo from `user` where 1 != 1 group by `user`.col, `user`.foo", - "Query": "select count(*), `user`.col, `user`.foo from `user` group by `user`.col, `user`.foo", - "Table": "`user`" - }, - { - "OperatorType": "Aggregate", - "Variant": "Ordered", - "Aggregates": "count_star(0)", - "GroupBy": "1, (2|3)", + "OperatorType": "SimpleProjection", + "Columns": [ + 2, + 0, + 1, + 3 + ], "Inputs": [ { - "OperatorType": "SimpleProjection", - "Columns": [ - 2, - 0, - 1, - 3 - ], + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "0 ASC, (1|3) ASC", "Inputs": [ { - "OperatorType": "Sort", - "Variant": "Memory", - "OrderBy": "0 ASC, (1|3) ASC", + "OperatorType": "Limit", + "Count": "10", "Inputs": [ { - "OperatorType": "Limit", - "Count": "10", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col, bar, 1, weight_string(bar) from (select col, bar from user_extra where 1 != 1) as ue where 1 != 1", - "Query": "select col, bar, 1, weight_string(bar) from (select col, bar from user_extra) as ue limit :__upper_limit", - "Table": "user_extra" - } - ] + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select ue.col, ue.bar, 1, weight_string(ue.bar) from (select col, bar from user_extra where 1 != 1) as ue where 1 != 1", + "Query": "select ue.col, ue.bar, 1, weight_string(ue.bar) from (select col, bar from user_extra) as ue limit :__upper_limit", + "Table": "user_extra" } ] } diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index e41aa27ce1b..c51a6f9144d 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -198,8 +198,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select phone, id, city from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", - "Query": "select phone, id, city from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", + "FieldQuery": "select x.phone, x.id, x.city from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", + "Query": "select x.phone, x.id, x.city from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", "Table": "`user`" } ] @@ -241,8 +241,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select phone, id, city, 1 from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", - "Query": "select phone, id, city, 1 from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", + "FieldQuery": "select x.phone, x.id, x.city, 1 from (select phone, id, city from `user` where 1 != 1) as x where 1 != 1", + "Query": "select x.phone, x.id, x.city, 1 from (select phone, id, city from `user` where id > 12) as x limit :__upper_limit", "Table": "`user`" } ] @@ -346,9 +346,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, val1, 1, weight_string(val1) from (select id, val1 from `user` where 1 != 1) as x where 1 != 1", + "FieldQuery": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where 1 != 1) as x where 1 != 1", "OrderBy": "(1|3) ASC", - "Query": "select id, val1, 1, weight_string(val1) from (select id, val1 from `user` where val2 < 4) as x order by val1 asc limit :__upper_limit", + "Query": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where val2 < 4) as x order by val1 asc limit :__upper_limit", "Table": "`user`" } ] @@ -691,8 +691,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, val2 from (select id, val2 from `user` where 1 != 1) as x where 1 != 1", - "Query": "select id, val2 from (select id, val2 from `user` where val2 is null) as x limit :__upper_limit", + "FieldQuery": "select x.id, x.val2 from (select id, val2 from `user` where 1 != 1) as x where 1 != 1", + "Query": "select x.id, x.val2 from (select id, val2 from `user` where val2 is null) as x limit :__upper_limit", "Table": "`user`" } ] @@ -1830,8 +1830,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, foo, weight_string(id), weight_string(foo) from (select id, foo from (select id, foo from `user` where 1 != 1) as x where 1 != 1 union select id, foo from (select id, foo from `user` where 1 != 1) as x where 1 != 1) as dt where 1 != 1", - "Query": "select id, foo, weight_string(id), weight_string(foo) from (select id, foo from (select id, foo from `user`) as x union select id, foo from (select id, foo from `user`) as x) as dt", + "FieldQuery": "select dt.id, dt.foo, weight_string(dt.id), weight_string(dt.foo) from (select id, foo from (select id, foo from `user` where 1 != 1) as x where 1 != 1 union select id, foo from (select id, foo from `user` where 1 != 1) as x where 1 != 1) as dt where 1 != 1", + "Query": "select dt.id, dt.foo, weight_string(dt.id), weight_string(dt.foo) from (select id, foo from (select id, foo from `user`) as x union select id, foo from (select id, foo from `user`) as x) as dt", "Table": "`user`" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 30b20f59087..0d9b1447a4c 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -4088,8 +4088,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col from (select col from user_extra where 1 != 1) as ue where 1 != 1", - "Query": "select col from (select col from user_extra) as ue limit :__upper_limit", + "FieldQuery": "select ue.col from (select col from user_extra where 1 != 1) as ue where 1 != 1", + "Query": "select ue.col from (select col from user_extra) as ue limit :__upper_limit", "Table": "user_extra" } ] @@ -4128,8 +4128,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, col from (select id, col from `user` where 1 != 1) as u where 1 != 1", - "Query": "select id, col from (select id, col from `user`) as u limit :__upper_limit", + "FieldQuery": "select u.id, u.col from (select id, col from `user` where 1 != 1) as u where 1 != 1", + "Query": "select u.id, u.col from (select id, col from `user`) as u limit :__upper_limit", "Table": "`user`" } ] @@ -4145,8 +4145,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col, user_id from (select col, user_id from user_extra where 1 != 1) as ue where 1 != 1", - "Query": "select col, user_id from (select col, user_id from user_extra) as ue limit :__upper_limit", + "FieldQuery": "select ue.col, ue.user_id from (select col, user_id from user_extra where 1 != 1) as ue where 1 != 1", + "Query": "select ue.col, ue.user_id from (select col, user_id from user_extra) as ue limit :__upper_limit", "Table": "user_extra" } ] @@ -4201,22 +4201,28 @@ "ResultColumns": 2, "Inputs": [ { - "OperatorType": "Projection", - "Expressions": [ - "id as id", - "user_id as user_id", - "weight_string(id) as weight_string(id)", - "weight_string(user_id) as weight_string(user_id)" - ], + "OperatorType": "Join", + "Variant": "HashLeftJoin", + "Collation": "binary", + "ComparisonType": "INT16", + "JoinColumnIndexes": "-1,2", + "Predicate": "u.col = ue.col", + "TableName": "`user`_user_extra", "Inputs": [ { - "OperatorType": "Join", - "Variant": "HashLeftJoin", - "Collation": "binary", - "ComparisonType": "INT16", - "JoinColumnIndexes": "-1,2", - "Predicate": "u.col = ue.col", - "TableName": "`user`_user_extra", + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id, col from (select id, col from `user` where 1 != 1) as u where 1 != 1", + "Query": "select distinct id, col from (select id, col from `user`) as u", + "Table": "`user`" + }, + { + "OperatorType": "Limit", + "Count": "10", "Inputs": [ { "OperatorType": "Route", @@ -4225,26 +4231,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, col from (select id, col from `user` where 1 != 1) as u where 1 != 1", - "Query": "select distinct id, col from (select id, col from `user`) as u", - "Table": "`user`" - }, - { - "OperatorType": "Limit", - "Count": "10", - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "Scatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col, user_id from (select col, user_id from user_extra where 1 != 1) as ue where 1 != 1", - "Query": "select col, user_id from (select col, user_id from user_extra) as ue limit :__upper_limit", - "Table": "user_extra" - } - ] + "FieldQuery": "select ue.col, ue.user_id from (select col, user_id from user_extra where 1 != 1) as ue where 1 != 1", + "Query": "select ue.col, ue.user_id from (select col, user_id from user_extra) as ue limit :__upper_limit", + "Table": "user_extra" } ] } @@ -4279,5 +4268,79 @@ "main.unsharded" ] } + }, + { + "comment": "pushing derived projection under the join should not cause problems", + "query": "SELECT count(*) FROM (SELECT DISTINCT u.user_id FROM user u JOIN user_extra ue ON u.id = ue.user_id JOIN music m ON m.id = u.id) subquery_for_count", + "plan": { + "QueryType": "SELECT", + "Original": "SELECT count(*) FROM (SELECT DISTINCT u.user_id FROM user u JOIN user_extra ue ON u.id = ue.user_id JOIN music m ON m.id = u.id) subquery_for_count", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "count_star(0) AS count(*)", + "Inputs": [ + { + "OperatorType": "SimpleProjection", + "Columns": [ + 1 + ], + "Inputs": [ + { + "OperatorType": "Distinct", + "Collations": [ + "(0:2)", + "1" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "R:0,L:1,R:1", + "JoinVars": { + "m_id": 0 + }, + "TableName": "music_`user`, user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select subquery_for_count.`m.id`, 1 from (select m.id as `m.id` from music as m where 1 != 1) as subquery_for_count where 1 != 1", + "Query": "select distinct subquery_for_count.`m.id`, 1 from (select m.id as `m.id` from music as m) as subquery_for_count", + "Table": "music" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select subquery_for_count.user_id, weight_string(subquery_for_count.user_id) from (select u.user_id from `user` as u, user_extra as ue where 1 != 1) as subquery_for_count where 1 != 1", + "Query": "select distinct subquery_for_count.user_id, weight_string(subquery_for_count.user_id) from (select u.user_id from `user` as u, user_extra as ue where u.id = :m_id and u.id = ue.user_id) as subquery_for_count", + "Table": "`user`, user_extra", + "Values": [ + ":m_id" + ], + "Vindex": "user_index" + } + ] + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.music", + "user.user", + "user.user_extra" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/large_union_cases.json b/go/vt/vtgate/planbuilder/testdata/large_union_cases.json index 2d66bc62d42..89adb07335a 100644 --- a/go/vt/vtgate/planbuilder/testdata/large_union_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/large_union_cases.json @@ -23,8 +23,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select content, user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where 1 != 1) union (select content, user_id from music where 1 != 1)) as dt where 1 != 1", - "Query": "select content, user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where user_id = 1270698330 order by created_at asc, id asc limit 11) union (select content, user_id from music where user_id = 1270698330 order by created_at asc, id asc limit 11)) as dt", + "FieldQuery": "select dt.content, dt.user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where 1 != 1) union (select content, user_id from music where 1 != 1)) as dt where 1 != 1", + "Query": "select dt.content, dt.user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where user_id = 1270698330 order by created_at asc, id asc limit 11) union (select content, user_id from music where user_id = 1270698330 order by created_at asc, id asc limit 11)) as dt", "Table": "music", "Values": [ "1270698330" @@ -38,8 +38,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select content, user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where 1 != 1) union (select content, user_id from music where 1 != 1)) as dt where 1 != 1", - "Query": "select content, user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where user_id = 1270699497 order by created_at asc, id asc limit 11) union (select content, user_id from music where user_id = 1270699497 order by created_at asc, id asc limit 11)) as dt", + "FieldQuery": "select dt.content, dt.user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where 1 != 1) union (select content, user_id from music where 1 != 1)) as dt where 1 != 1", + "Query": "select dt.content, dt.user_id, weight_string(content), weight_string(user_id) from ((select content, user_id from music where user_id = 1270699497 order by created_at asc, id asc limit 11) union (select content, user_id from music where user_id = 1270699497 order by created_at asc, id asc limit 11)) as dt", "Table": "music", "Values": [ "1270699497" diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 039093cd0c7..59a6e4686a6 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -2021,11 +2021,11 @@ } }, { - "comment": "select (select col from user limit 1) as a from user join user_extra order by a", - "query": "select (select col from user limit 1) as a from user join user_extra order by a", + "comment": "subquery in select expression of derived table", + "query": "select t.a from (select (select col from user limit 1) as a from user join user_extra) t", "plan": { "QueryType": "SELECT", - "Original": "select (select col from user limit 1) as a from user join user_extra order by a", + "Original": "select t.a from (select (select col from user limit 1) as a from user join user_extra) t", "Instructions": { "OperatorType": "Join", "Variant": "Join", @@ -2065,9 +2065,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select :__sq1 as __sq1, weight_string(:__sq1) from `user` where 1 != 1", - "OrderBy": "(0|1) ASC", - "Query": "select :__sq1 as __sq1, weight_string(:__sq1) from `user` order by __sq1 asc", + "FieldQuery": "select t.a from (select :__sq1 as a from `user` where 1 != 1) as t where 1 != 1", + "Query": "select t.a from (select :__sq1 as a from `user`) as t", "Table": "`user`" } ] @@ -2092,11 +2091,11 @@ } }, { - "comment": "select t.a from (select (select col from user limit 1) as a from user join user_extra) t", - "query": "select t.a from (select (select col from user limit 1) as a from user join user_extra) t", + "comment": "select (select col from user limit 1) as a from user join user_extra order by a", + "query": "select (select col from user limit 1) as a from user join user_extra order by a", "plan": { "QueryType": "SELECT", - "Original": "select t.a from (select (select col from user limit 1) as a from user join user_extra) t", + "Original": "select (select col from user limit 1) as a from user join user_extra order by a", "Instructions": { "OperatorType": "Join", "Variant": "Join", @@ -2136,8 +2135,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select t.a from (select :__sq1 as a from `user` where 1 != 1) as t where 1 != 1", - "Query": "select t.a from (select :__sq1 as a from `user`) as t", + "FieldQuery": "select :__sq1 as __sq1, weight_string(:__sq1) from `user` where 1 != 1", + "OrderBy": "(0|1) ASC", + "Query": "select :__sq1 as __sq1, weight_string(:__sq1) from `user` order by __sq1 asc", "Table": "`user`" } ] @@ -4223,8 +4223,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select user_id from (select user_id from user_extra where 1 != 1) as ue where 1 != 1", - "Query": "select user_id from (select user_id from user_extra) as ue limit :__upper_limit", + "FieldQuery": "select ue.user_id from (select user_id from user_extra where 1 != 1) as ue where 1 != 1", + "Query": "select ue.user_id from (select user_id from user_extra) as ue limit :__upper_limit", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json index a5c144df355..2d225808992 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json @@ -578,9 +578,9 @@ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:0,R:0,L:1,R:1,L:2,L:5,R:2,L:6", + "JoinColumnIndexes": "L:0,R:0,L:1,R:1,L:2,L:4,R:2,L:5", "JoinVars": { - "n1_n_name": 4, + "n1_n_name": 1, "o_custkey": 3 }, "TableName": "lineitem_orders_supplier_nation_customer_nation", @@ -591,18 +591,17 @@ "sum(volume) * count(*) as revenue", ":2 as supp_nation", ":3 as l_year", - ":4 as orders.o_custkey", - ":5 as n1.n_name", - ":6 as weight_string(supp_nation)", - ":7 as weight_string(l_year)" + ":4 as o_custkey", + ":5 as weight_string(supp_nation)", + ":6 as weight_string(l_year)" ], "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:0,R:0,R:1,L:1,L:2,L:3,R:2,L:5", + "JoinColumnIndexes": "L:0,R:0,R:1,L:1,L:2,R:2,L:4", "JoinVars": { - "l_suppkey": 4 + "l_suppkey": 3 }, "TableName": "lineitem_orders_supplier_nation", "Inputs": [ @@ -611,18 +610,17 @@ "Expressions": [ "sum(volume) * count(*) as revenue", ":2 as l_year", - ":3 as orders.o_custkey", - ":4 as n1.n_name", - ":5 as lineitem.l_suppkey", - ":6 as weight_string(l_year)" + ":3 as o_custkey", + ":4 as l_suppkey", + ":5 as weight_string(l_year)" ], "Inputs": [ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:0,R:0,L:1,L:2,L:3,L:4,L:6", + "JoinColumnIndexes": "L:0,R:0,L:1,R:1,L:2,L:4", "JoinVars": { - "l_orderkey": 5 + "l_orderkey": 3 }, "TableName": "lineitem_orders", "Inputs": [ @@ -633,9 +631,9 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select sum(volume) as revenue, l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, orders.o_custkey as `orders.o_custkey`, lineitem.l_suppkey as `lineitem.l_suppkey`, lineitem.l_orderkey as `lineitem.l_orderkey` from lineitem where 1 != 1) as shipping where 1 != 1 group by l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year)", - "OrderBy": "(7|8) ASC, (9|10) ASC, (1|6) ASC", - "Query": "select sum(volume) as revenue, l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, orders.o_custkey as `orders.o_custkey`, lineitem.l_suppkey as `lineitem.l_suppkey`, lineitem.l_orderkey as `lineitem.l_orderkey` from lineitem where l_shipdate between date('1995-01-01') and date('1996-12-31')) as shipping group by l_year, shipping.`orders.o_custkey`, shipping.`n1.n_name`, shipping.`lineitem.l_suppkey`, shipping.`lineitem.l_orderkey`, weight_string(l_year) order by supp_nation asc, cust_nation asc, l_year asc", + "FieldQuery": "select sum(volume) as revenue, l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, l_suppkey as l_suppkey, l_orderkey as l_orderkey from lineitem where 1 != 1) as shipping where 1 != 1 group by l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year)", + "OrderBy": "(5|6) ASC, (7|8) ASC, (1|4) ASC", + "Query": "select sum(volume) as revenue, l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, l_suppkey as l_suppkey, l_orderkey as l_orderkey from lineitem where l_shipdate between date('1995-01-01') and date('1996-12-31')) as shipping group by l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year) order by supp_nation asc, cust_nation asc, l_year asc", "Table": "lineitem" }, { @@ -645,8 +643,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select count(*) from orders where 1 != 1 group by .0", - "Query": "select count(*) from orders where o_orderkey = :l_orderkey group by .0", + "FieldQuery": "select count(*), shipping.o_custkey from (select o_custkey as o_custkey from orders where 1 != 1) as shipping where 1 != 1 group by shipping.o_custkey", + "Query": "select count(*), shipping.o_custkey from (select o_custkey as o_custkey from orders where o_orderkey = :l_orderkey) as shipping group by shipping.o_custkey", "Table": "orders", "Values": [ ":l_orderkey" @@ -681,8 +679,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select count(*), shipping.`supplier.s_nationkey` from (select supplier.s_nationkey as `supplier.s_nationkey` from supplier where 1 != 1) as shipping where 1 != 1 group by shipping.`supplier.s_nationkey`", - "Query": "select count(*), shipping.`supplier.s_nationkey` from (select supplier.s_nationkey as `supplier.s_nationkey` from supplier where s_suppkey = :l_suppkey) as shipping group by shipping.`supplier.s_nationkey`", + "FieldQuery": "select count(*), shipping.s_nationkey from (select s_nationkey as s_nationkey from supplier where 1 != 1) as shipping where 1 != 1 group by shipping.s_nationkey", + "Query": "select count(*), shipping.s_nationkey from (select s_nationkey as s_nationkey from supplier where s_suppkey = :l_suppkey) as shipping group by shipping.s_nationkey", "Table": "supplier", "Values": [ ":l_suppkey" @@ -696,8 +694,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select count(*), supp_nation, weight_string(supp_nation) from (select n1.n_name as supp_nation from nation as n1 where 1 != 1) as shipping where 1 != 1 group by supp_nation, weight_string(supp_nation)", - "Query": "select count(*), supp_nation, weight_string(supp_nation) from (select n1.n_name as supp_nation from nation as n1 where n1.n_nationkey = :s_nationkey) as shipping group by supp_nation, weight_string(supp_nation)", + "FieldQuery": "select count(*), supp_nation, weight_string(supp_nation) from (select n1.n_name as supp_nation, n1.n_name = 'FRANCE' as `n1.n_name = 'FRANCE'`, n1.n_name = 'GERMANY' as `n1.n_name = 'GERMANY'` from nation as n1 where 1 != 1) as shipping where 1 != 1 group by supp_nation, weight_string(supp_nation)", + "Query": "select count(*), supp_nation, weight_string(supp_nation) from (select n1.n_name as supp_nation, n1.n_name = 'FRANCE' as `n1.n_name = 'FRANCE'`, n1.n_name = 'GERMANY' as `n1.n_name = 'GERMANY'` from nation as n1 where n1.n_nationkey = :s_nationkey) as shipping group by supp_nation, weight_string(supp_nation)", "Table": "nation", "Values": [ ":s_nationkey" @@ -736,8 +734,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select count(*), shipping.`customer.c_nationkey` from (select customer.c_nationkey as `customer.c_nationkey` from customer where 1 != 1) as shipping where 1 != 1 group by shipping.`customer.c_nationkey`", - "Query": "select count(*), shipping.`customer.c_nationkey` from (select customer.c_nationkey as `customer.c_nationkey` from customer where c_custkey = :o_custkey) as shipping group by shipping.`customer.c_nationkey`", + "FieldQuery": "select count(*), shipping.c_nationkey from (select c_nationkey as c_nationkey from customer where 1 != 1) as shipping where 1 != 1 group by shipping.c_nationkey", + "Query": "select count(*), shipping.c_nationkey from (select c_nationkey as c_nationkey from customer where c_custkey = :o_custkey) as shipping group by shipping.c_nationkey", "Table": "customer", "Values": [ ":o_custkey" diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.json b/go/vt/vtgate/planbuilder/testdata/union_cases.json index 3f25b60556b..12a709d023f 100644 --- a/go/vt/vtgate/planbuilder/testdata/union_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.json @@ -42,8 +42,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from (select id from `user` where 1 != 1 union select id from music where 1 != 1) as dt where 1 != 1", - "Query": "select id, weight_string(id) from (select id from `user` union select id from music) as dt", + "FieldQuery": "select dt.id, weight_string(dt.id) from (select id from `user` where 1 != 1 union select id from music where 1 != 1) as dt where 1 != 1", + "Query": "select dt.id, weight_string(dt.id) from (select id from `user` union select id from music) as dt", "Table": "`user`, music" } ] @@ -384,8 +384,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from (select id from `user` where 1 != 1 union select id from music where 1 != 1 union select 1 from dual where 1 != 1) as dt where 1 != 1", - "Query": "select id, weight_string(id) from (select id from `user` union select id from music union select 1 from dual) as dt", + "FieldQuery": "select dt.id, weight_string(dt.id) from (select id from `user` where 1 != 1 union select id from music where 1 != 1 union select 1 from dual where 1 != 1) as dt where 1 != 1", + "Query": "select dt.id, weight_string(dt.id) from (select id from `user` union select id from music union select 1 from dual) as dt", "Table": "`user`, dual, music" } ] @@ -503,8 +503,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from ((select id from `user` where 1 != 1) union (select id from `user` where 1 != 1)) as dt where 1 != 1", - "Query": "select id, weight_string(id) from ((select id from `user` order by id desc) union (select id from `user` order by id asc)) as dt", + "FieldQuery": "select dt.id, weight_string(dt.id) from ((select id from `user` where 1 != 1) union (select id from `user` where 1 != 1)) as dt where 1 != 1", + "Query": "select dt.id, weight_string(dt.id) from ((select id from `user` order by id desc) union (select id from `user` order by id asc)) as dt", "Table": "`user`" } ] @@ -761,8 +761,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from (select id from `user` where 1 != 1 union select id + 1 from `user` where 1 != 1 union select user_id from user_extra where 1 != 1) as dt where 1 != 1", - "Query": "select id, weight_string(id) from (select id from `user` union select id + 1 from `user` union select user_id from user_extra) as dt", + "FieldQuery": "select dt.id, weight_string(dt.id) from (select id from `user` where 1 != 1 union select id + 1 from `user` where 1 != 1 union select user_id from user_extra where 1 != 1) as dt where 1 != 1", + "Query": "select dt.id, weight_string(dt.id) from (select id from `user` union select id + 1 from `user` union select user_id from user_extra) as dt", "Table": "`user`, user_extra" } ] @@ -796,8 +796,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from (select id from `user` where 1 != 1 union select id from music where 1 != 1) as dt where 1 != 1", - "Query": "select id, weight_string(id) from (select id from `user` union select id from music) as dt", + "FieldQuery": "select id, weight_string(dt.id) from (select id from `user` where 1 != 1 union select id from music where 1 != 1) as dt where 1 != 1", + "Query": "select id, weight_string(dt.id) from (select id from `user` union select id from music) as dt", "Table": "`user`, music" }, { @@ -846,8 +846,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from (select id from `user` where 1 != 1 union select 3 from dual where 1 != 1) as dt where 1 != 1", - "Query": "select id, weight_string(id) from (select id from `user` union select 3 from dual limit :__upper_limit) as dt", + "FieldQuery": "select id, weight_string(dt.id) from (select id from `user` where 1 != 1 union select 3 from dual where 1 != 1) as dt where 1 != 1", + "Query": "select id, weight_string(dt.id) from (select id from `user` union select 3 from dual limit :__upper_limit) as dt", "Table": "`user`, dual" } ] @@ -905,8 +905,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select col, weight_string(col) from (select col from unsharded where 1 != 1 union select col2 from unsharded where 1 != 1) as dt where 1 != 1", - "Query": "select col, weight_string(col) from (select col from unsharded union select col2 from unsharded) as dt", + "FieldQuery": "select dt.col, weight_string(col) from (select col from unsharded where 1 != 1 union select col2 from unsharded where 1 != 1) as dt where 1 != 1", + "Query": "select dt.col, weight_string(col) from (select col from unsharded union select col2 from unsharded) as dt", "Table": "unsharded" }, { @@ -1055,8 +1055,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select id, weight_string(id) from (select id from `user` where 1 != 1 union select 3 from dual where 1 != 1) as dt where 1 != 1", - "Query": "select id, weight_string(id) from (select id from `user` union select 3 from dual) as dt", + "FieldQuery": "select dt.id, weight_string(dt.id) from (select id from `user` where 1 != 1 union select 3 from dual where 1 != 1) as dt where 1 != 1", + "Query": "select dt.id, weight_string(dt.id) from (select id from `user` union select 3 from dual) as dt", "Table": "`user`, dual" } ] @@ -1428,8 +1428,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select bar, baz, toto, weight_string(bar), weight_string(baz), weight_string(toto) from (select bar, baz, toto from music where 1 != 1 union select foo, foo, foo from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select bar, baz, toto, weight_string(bar), weight_string(baz), weight_string(toto) from (select bar, baz, toto from music union select foo, foo, foo from `user`) as dt", + "FieldQuery": "select dt.bar, dt.baz, dt.toto, weight_string(dt.bar), weight_string(dt.baz), weight_string(dt.toto) from (select bar, baz, toto from music where 1 != 1 union select foo, foo, foo from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.bar, dt.baz, dt.toto, weight_string(dt.bar), weight_string(dt.baz), weight_string(dt.toto) from (select bar, baz, toto from music union select foo, foo, foo from `user`) as dt", "Table": "`user`, music" } ] @@ -1462,8 +1462,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select foo, foo, foo, weight_string(foo) from (select foo, foo, foo from `user` where 1 != 1 union select bar, baz, toto from music where 1 != 1) as dt where 1 != 1", - "Query": "select foo, foo, foo, weight_string(foo) from (select foo, foo, foo from `user` union select bar, baz, toto from music) as dt", + "FieldQuery": "select dt.foo, dt.foo, dt.foo, weight_string(dt.foo) from (select foo, foo, foo from `user` where 1 != 1 union select bar, baz, toto from music where 1 != 1) as dt where 1 != 1", + "Query": "select dt.foo, dt.foo, dt.foo, weight_string(dt.foo) from (select foo, foo, foo from `user` union select bar, baz, toto from music) as dt", "Table": "`user`, music" } ] @@ -1502,8 +1502,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select foo, weight_string(foo) from (select foo from `user` where 1 != 1 union select foo from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select foo, weight_string(foo) from (select foo from `user` where bar = 12 union select foo from `user` where bar = 134) as dt", + "FieldQuery": "select dt.foo, weight_string(dt.foo) from (select foo from `user` where 1 != 1 union select foo from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.foo, weight_string(dt.foo) from (select foo from `user` where bar = 12 union select foo from `user` where bar = 134) as dt", "Table": "`user`" } ] @@ -1521,8 +1521,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select bar, weight_string(bar) from (select bar from music where 1 != 1 union select bar from music where 1 != 1) as dt where 1 != 1", - "Query": "select bar, weight_string(bar) from (select bar from music where foo = 12 and bar = :t1_foo union select bar from music where foo = 1234 and bar = :t1_foo) as dt", + "FieldQuery": "select dt.bar, weight_string(dt.bar) from (select bar from music where 1 != 1 union select bar from music where 1 != 1) as dt where 1 != 1", + "Query": "select dt.bar, weight_string(dt.bar) from (select bar from music where foo = 12 and bar = :t1_foo union select bar from music where foo = 1234 and bar = :t1_foo) as dt", "Table": "music" } ] @@ -1585,8 +1585,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col1, weight_string(col1) from (select col1 from `user` where 1 != 1 union select 3 from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select col1, weight_string(col1) from (select col1 from `user` union select 3 from `user`) as dt", + "FieldQuery": "select dt.col1, weight_string(dt.col1) from (select col1 from `user` where 1 != 1 union select 3 from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.col1, weight_string(dt.col1) from (select col1 from `user` union select 3 from `user`) as dt", "Table": "`user`" } ] @@ -1616,8 +1616,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `3`, weight_string(`3`) from (select 3 from `user` where 1 != 1 union select col1 from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select `3`, weight_string(`3`) from (select 3 from `user` union select col1 from `user`) as dt", + "FieldQuery": "select dt.`3`, weight_string(dt.`3`) from (select 3 from `user` where 1 != 1 union select col1 from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.`3`, weight_string(dt.`3`) from (select 3 from `user` union select col1 from `user`) as dt", "Table": "`user`" } ] @@ -1647,8 +1647,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `3`, weight_string(`3`) from (select 3 from `user` where 1 != 1 union select now() from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select `3`, weight_string(`3`) from (select 3 from `user` union select now() from `user`) as dt", + "FieldQuery": "select dt.`3`, weight_string(dt.`3`) from (select 3 from `user` where 1 != 1 union select now() from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.`3`, weight_string(dt.`3`) from (select 3 from `user` union select now() from `user`) as dt", "Table": "`user`" } ] @@ -1678,8 +1678,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `now()`, weight_string(`now()`) from (select now() from `user` where 1 != 1 union select 3 from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select `now()`, weight_string(`now()`) from (select now() from `user` union select 3 from `user`) as dt", + "FieldQuery": "select dt.`now()`, weight_string(dt.`now()`) from (select now() from `user` where 1 != 1 union select 3 from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.`now()`, weight_string(dt.`now()`) from (select now() from `user` union select 3 from `user`) as dt", "Table": "`user`" } ] @@ -1709,8 +1709,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `now()`, weight_string(`now()`) from (select now() from `user` where 1 != 1 union select id from `user` where 1 != 1) as dt where 1 != 1", - "Query": "select `now()`, weight_string(`now()`) from (select now() from `user` union select id from `user`) as dt", + "FieldQuery": "select dt.`now()`, weight_string(dt.`now()`) from (select now() from `user` where 1 != 1 union select id from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.`now()`, weight_string(dt.`now()`) from (select now() from `user` union select id from `user`) as dt", "Table": "`user`" } ] diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 9067e77dd88..5f2529a6e83 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -806,7 +806,7 @@ func (st *SemTable) SingleKeyspace() (ks *vindexes.Keyspace) { func (st *SemTable) EqualsExpr(a, b sqlparser.Expr) bool { // If there is no SemTable, then we cannot compare the expressions. if st == nil { - return false + return sqlparser.Equals.Expr(a, b) } return st.ASTEquals().Expr(a, b) }