diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index 0900384dc5e..f7bd5b131b8 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -69,11 +69,10 @@ type ( // so they can be used for the result of this expression that is using data from both sides. // All fields will be used for these applyJoinColumn struct { - Original sqlparser.Expr // this is the original expression being passed through - LHSExprs []BindVarExpr // These are the expressions we are pushing to the left hand side which we'll receive as bind variables - RHSExpr sqlparser.Expr // This the expression that we'll evaluate on the right hand side. This is nil, if the right hand side has nothing. - DTColName *sqlparser.ColName // This is the output column name that the parent of JOIN will be seeing. If this is unset, then the colname is the String(Original). We set this when we push Projections with derived tables underneath a Join. - GroupBy bool // if this is true, we need to push this down to our inputs with addToGroupBy set to true + Original sqlparser.Expr // this is the original expression being passed through + LHSExprs []BindVarExpr // These are the expressions we are pushing to the left hand side which we'll receive as bind variables + RHSExpr sqlparser.Expr // This the expression that we'll evaluate on the right hand side. This is nil, if the right hand side has nothing. + GroupBy bool // if this is true, we need to push this down to our inputs with addToGroupBy set to true } // BindVarExpr is an expression needed from one side of a join/subquery, and the argument name for it. @@ -225,8 +224,7 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq func applyJoinCompare(ctx *plancontext.PlanningContext, expr sqlparser.Expr) func(e applyJoinColumn) bool { return func(e applyJoinColumn) bool { - // e.DTColName is how the outside world will be using this expression. So we should check for an equality with that too. - return ctx.SemTable.EqualsExprWithDeps(e.Original, expr) || ctx.SemTable.EqualsExprWithDeps(e.DTColName, expr) + return ctx.SemTable.EqualsExprWithDeps(e.Original, expr) } } @@ -447,11 +445,8 @@ func (jc applyJoinColumn) String() string { lhs := slice.Map(jc.LHSExprs, func(e BindVarExpr) string { return sqlparser.String(e.Expr) }) - if jc.DTColName == nil { - return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original)) - } - return fmt.Sprintf("[%s | %s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original), sqlparser.String(jc.DTColName)) + return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original)) } func (jc applyJoinColumn) IsPureLeft() bool { @@ -467,16 +462,10 @@ func (jc applyJoinColumn) IsMixedLeftAndRight() bool { } func (jc applyJoinColumn) GetPureLeftExpr() sqlparser.Expr { - if jc.DTColName != nil { - return jc.DTColName - } return jc.LHSExprs[0].Expr } func (jc applyJoinColumn) GetRHSExpr() sqlparser.Expr { - if jc.DTColName != nil { - return jc.DTColName - } return jc.RHSExpr } diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index 616731bcfb7..95ebeadaeb7 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -87,7 +87,7 @@ type ( ProjExpr struct { Original *sqlparser.AliasedExpr // this is the expression the user asked for. should only be used to decide on the column alias - EvalExpr sqlparser.Expr // EvalExpr is the expression that will be evaluated at runtime + EvalExpr sqlparser.Expr // EvalExpr represents the expression evaluated at runtime or used when the ProjExpr is pushed under a route ColExpr sqlparser.Expr // ColExpr is used during planning to figure out which column this ProjExpr is representing Info ExprInfo // Here we store information about evalengine, offsets or subqueries } diff --git a/go/vt/vtgate/planbuilder/operators/projection_pushing.go b/go/vt/vtgate/planbuilder/operators/projection_pushing.go index b5c6a10bb78..21d2a058a86 100644 --- a/go/vt/vtgate/planbuilder/operators/projection_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/projection_pushing.go @@ -320,7 +320,7 @@ func splitSubqueryExpression( alias string, ) applyJoinColumn { col := join.getJoinColumnFor(ctx, pe.Original, pe.ColExpr, false) - return pushDownSplitJoinCol(col, lhs, pe, alias, rhs) + return pushDownSplitJoinCol(col, lhs, rhs, pe, alias) } func splitUnexploredExpression( @@ -334,23 +334,50 @@ func splitUnexploredExpression( original := sqlparser.Clone(pe.Original) expr := pe.ColExpr - var colName *sqlparser.ColName - if dt != nil { - if !pe.isSameInAndOut(ctx) { - panic(vterrors.VT13001("derived table columns must be the same in and out")) - } - colName = sqlparser.NewColNameWithQualifier(pe.Original.ColumnName(), sqlparser.NewTableName(dt.Alias)) - ctx.SemTable.CopySemanticInfo(expr, colName) - } - // Get a applyJoinColumn for the current expression. col := join.getJoinColumnFor(ctx, original, expr, false) - col.DTColName = colName - return pushDownSplitJoinCol(col, lhs, pe, alias, rhs) + if dt == nil { + return pushDownSplitJoinCol(col, lhs, rhs, pe, alias) + } + + if !pe.isSameInAndOut(ctx) { + panic(vterrors.VT13001("derived table columns must be the same in and out")) + } + // we are pushing a derived projection through a join. that means that after this rewrite, we are on top of the + // derived table divider, and can only see the projected columns, not the underlying expressions + colName := sqlparser.NewColNameWithQualifier(pe.Original.ColumnName(), sqlparser.NewTableName(dt.Alias)) + ctx.SemTable.CopySemanticInfo(expr, colName) + col.Original = colName + if alias == "" { + alias = pe.Original.ColumnName() + } + + // Update the left and right child columns and names based on the applyJoinColumn type. + switch { + case col.IsPureLeft(): + lhs.add(pe, alias) + col.LHSExprs[0].Expr = colName + case col.IsPureRight(): + rhs.add(pe, alias) + col.RHSExpr = colName + case col.IsMixedLeftAndRight(): + for _, lhsExpr := range col.LHSExprs { + ae := aeWrap(lhsExpr.Expr) + columnName := ae.ColumnName() + ae.As = sqlparser.NewIdentifierCI(columnName) + lhs.add(newProjExpr(ae), columnName) + } + innerPE := newProjExprWithInner(pe.Original, col.RHSExpr) + innerPE.ColExpr = col.RHSExpr + col.RHSExpr = colName + innerPE.Info = pe.Info + rhs.add(innerPE, alias) + } + return col } -func pushDownSplitJoinCol(col applyJoinColumn, lhs *projector, pe *ProjExpr, alias string, rhs *projector) applyJoinColumn { +func pushDownSplitJoinCol(col applyJoinColumn, lhs, rhs *projector, pe *ProjExpr, alias string) applyJoinColumn { // Update the left and right child columns and names based on the applyJoinColumn type. switch { case col.IsPureLeft(): diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index da869bb5e69..9a0ffa37f0c 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -4803,8 +4803,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select dt.foo from (select u.foo from `user` as u where 1 != 1) as dt where 1 != 1", - "Query": "select dt.foo from (select u.foo from `user` as u) as dt", + "FieldQuery": "select dt.foo from (select u.foo as foo from `user` as u where 1 != 1) as dt where 1 != 1", + "Query": "select dt.foo from (select u.foo as foo from `user` as u) as dt", "Table": "`user`" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json index 2f8a927d0c8..6e67d4f9e8d 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json @@ -1132,7 +1132,7 @@ { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:0,R:0,R:1,R:2,R:3,R:4,L:3,R:5", + "JoinColumnIndexes": "L:0,R:0,R:1,R:2,R:3,R:4,L:2,R:5", "JoinVars": { "o_orderkey": 1 }, @@ -1145,8 +1145,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select profit.o_year, profit.o_orderkey, weight_string(profit.o_year), weight_string(extract(year from o_orderdate)) from (select extract(year from o_orderdate) as o_year, o_orderkey as o_orderkey from orders where 1 != 1) as profit where 1 != 1", - "Query": "select profit.o_year, profit.o_orderkey, weight_string(profit.o_year), weight_string(extract(year from o_orderdate)) from (select extract(year from o_orderdate) as o_year, o_orderkey as o_orderkey from orders) as profit", + "FieldQuery": "select profit.o_year, profit.o_orderkey, weight_string(profit.o_year) from (select extract(year from o_orderdate) as o_year, o_orderkey as o_orderkey from orders where 1 != 1) as profit where 1 != 1", + "Query": "select profit.o_year, profit.o_orderkey, weight_string(profit.o_year) from (select extract(year from o_orderdate) as o_year, o_orderkey as o_orderkey from orders) as profit", "Table": "orders" }, { @@ -1192,8 +1192,8 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select profit.l_extendedprice, profit.l_discount, profit.l_quantity, profit.l_suppkey, profit.l_partkey, weight_string(profit.l_suppkey) from (select l_extendedprice, l_discount, l_quantity, l_suppkey as l_suppkey, l_partkey as l_partkey from lineitem where 1 != 1) as profit where 1 != 1", - "Query": "select profit.l_extendedprice, profit.l_discount, profit.l_quantity, profit.l_suppkey, profit.l_partkey, weight_string(profit.l_suppkey) from (select l_extendedprice, l_discount, l_quantity, l_suppkey as l_suppkey, l_partkey as l_partkey from lineitem where l_orderkey = :o_orderkey) as profit", + "FieldQuery": "select profit.l_extendedprice, profit.l_discount, profit.l_quantity, profit.l_suppkey, profit.l_partkey, weight_string(profit.l_suppkey) from (select l_extendedprice as l_extendedprice, l_discount as l_discount, l_quantity as l_quantity, l_suppkey as l_suppkey, l_partkey as l_partkey from lineitem where 1 != 1) as profit where 1 != 1", + "Query": "select profit.l_extendedprice, profit.l_discount, profit.l_quantity, profit.l_suppkey, profit.l_partkey, weight_string(profit.l_suppkey) from (select l_extendedprice as l_extendedprice, l_discount as l_discount, l_quantity as l_quantity, l_suppkey as l_suppkey, l_partkey as l_partkey from lineitem where l_orderkey = :o_orderkey) as profit", "Table": "lineitem" } ]