Skip to content

Commit

Permalink
handle derived projection pushing better
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay committed Aug 14, 2024
1 parent 83f560c commit 703dfa4
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 38 deletions.
23 changes: 6 additions & 17 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,10 @@ type (
// so they can be used for the result of this expression that is using data from both sides.
// All fields will be used for these
applyJoinColumn struct {
Original sqlparser.Expr // this is the original expression being passed through
LHSExprs []BindVarExpr // These are the expressions we are pushing to the left hand side which we'll receive as bind variables
RHSExpr sqlparser.Expr // This the expression that we'll evaluate on the right hand side. This is nil, if the right hand side has nothing.
DTColName *sqlparser.ColName // This is the output column name that the parent of JOIN will be seeing. If this is unset, then the colname is the String(Original). We set this when we push Projections with derived tables underneath a Join.
GroupBy bool // if this is true, we need to push this down to our inputs with addToGroupBy set to true
Original sqlparser.Expr // this is the original expression being passed through
LHSExprs []BindVarExpr // These are the expressions we are pushing to the left hand side which we'll receive as bind variables
RHSExpr sqlparser.Expr // This the expression that we'll evaluate on the right hand side. This is nil, if the right hand side has nothing.
GroupBy bool // if this is true, we need to push this down to our inputs with addToGroupBy set to true
}

// BindVarExpr is an expression needed from one side of a join/subquery, and the argument name for it.
Expand Down Expand Up @@ -225,8 +224,7 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq

func applyJoinCompare(ctx *plancontext.PlanningContext, expr sqlparser.Expr) func(e applyJoinColumn) bool {
return func(e applyJoinColumn) bool {
// e.DTColName is how the outside world will be using this expression. So we should check for an equality with that too.
return ctx.SemTable.EqualsExprWithDeps(e.Original, expr) || ctx.SemTable.EqualsExprWithDeps(e.DTColName, expr)
return ctx.SemTable.EqualsExprWithDeps(e.Original, expr)
}
}

Expand Down Expand Up @@ -447,11 +445,8 @@ func (jc applyJoinColumn) String() string {
lhs := slice.Map(jc.LHSExprs, func(e BindVarExpr) string {
return sqlparser.String(e.Expr)
})
if jc.DTColName == nil {
return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original))
}

return fmt.Sprintf("[%s | %s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original), sqlparser.String(jc.DTColName))
return fmt.Sprintf("[%s | %s | %s]", strings.Join(lhs, ", "), rhs, sqlparser.String(jc.Original))
}

func (jc applyJoinColumn) IsPureLeft() bool {
Expand All @@ -467,16 +462,10 @@ func (jc applyJoinColumn) IsMixedLeftAndRight() bool {
}

func (jc applyJoinColumn) GetPureLeftExpr() sqlparser.Expr {
if jc.DTColName != nil {
return jc.DTColName
}
return jc.LHSExprs[0].Expr
}

func (jc applyJoinColumn) GetRHSExpr() sqlparser.Expr {
if jc.DTColName != nil {
return jc.DTColName
}
return jc.RHSExpr
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ type (

ProjExpr struct {
Original *sqlparser.AliasedExpr // this is the expression the user asked for. should only be used to decide on the column alias
EvalExpr sqlparser.Expr // EvalExpr is the expression that will be evaluated at runtime
EvalExpr sqlparser.Expr // EvalExpr represents the expression evaluated at runtime or used when the ProjExpr is pushed under a route
ColExpr sqlparser.Expr // ColExpr is used during planning to figure out which column this ProjExpr is representing
Info ExprInfo // Here we store information about evalengine, offsets or subqueries
}
Expand Down
53 changes: 40 additions & 13 deletions go/vt/vtgate/planbuilder/operators/projection_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ func splitSubqueryExpression(
alias string,
) applyJoinColumn {
col := join.getJoinColumnFor(ctx, pe.Original, pe.ColExpr, false)
return pushDownSplitJoinCol(col, lhs, pe, alias, rhs)
return pushDownSplitJoinCol(col, lhs, rhs, pe, alias)
}

func splitUnexploredExpression(
Expand All @@ -334,23 +334,50 @@ func splitUnexploredExpression(
original := sqlparser.Clone(pe.Original)
expr := pe.ColExpr

var colName *sqlparser.ColName
if dt != nil {
if !pe.isSameInAndOut(ctx) {
panic(vterrors.VT13001("derived table columns must be the same in and out"))
}
colName = sqlparser.NewColNameWithQualifier(pe.Original.ColumnName(), sqlparser.NewTableName(dt.Alias))
ctx.SemTable.CopySemanticInfo(expr, colName)
}

// Get a applyJoinColumn for the current expression.
col := join.getJoinColumnFor(ctx, original, expr, false)
col.DTColName = colName

return pushDownSplitJoinCol(col, lhs, pe, alias, rhs)
if dt == nil {
return pushDownSplitJoinCol(col, lhs, rhs, pe, alias)
}

if !pe.isSameInAndOut(ctx) {
panic(vterrors.VT13001("derived table columns must be the same in and out"))
}
// we are pushing a derived projection through a join. that means that after this rewrite, we are on top of the
// derived table divider, and can only see the projected columns, not the underlying expressions
colName := sqlparser.NewColNameWithQualifier(pe.Original.ColumnName(), sqlparser.NewTableName(dt.Alias))
ctx.SemTable.CopySemanticInfo(expr, colName)
col.Original = colName
if alias == "" {
alias = pe.Original.ColumnName()
}

// Update the left and right child columns and names based on the applyJoinColumn type.
switch {
case col.IsPureLeft():
lhs.add(pe, alias)
col.LHSExprs[0].Expr = colName
case col.IsPureRight():
rhs.add(pe, alias)
col.RHSExpr = colName
case col.IsMixedLeftAndRight():
for _, lhsExpr := range col.LHSExprs {
ae := aeWrap(lhsExpr.Expr)
columnName := ae.ColumnName()
ae.As = sqlparser.NewIdentifierCI(columnName)
lhs.add(newProjExpr(ae), columnName)
}
innerPE := newProjExprWithInner(pe.Original, col.RHSExpr)
innerPE.ColExpr = col.RHSExpr
col.RHSExpr = colName
innerPE.Info = pe.Info
rhs.add(innerPE, alias)
}
return col
}

func pushDownSplitJoinCol(col applyJoinColumn, lhs *projector, pe *ProjExpr, alias string, rhs *projector) applyJoinColumn {
func pushDownSplitJoinCol(col applyJoinColumn, lhs, rhs *projector, pe *ProjExpr, alias string) applyJoinColumn {
// Update the left and right child columns and names based on the applyJoinColumn type.
switch {
case col.IsPureLeft():
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/testdata/from_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -4803,8 +4803,8 @@
"Name": "user",
"Sharded": true
},
"FieldQuery": "select dt.foo from (select u.foo from `user` as u where 1 != 1) as dt where 1 != 1",
"Query": "select dt.foo from (select u.foo from `user` as u) as dt",
"FieldQuery": "select dt.foo from (select u.foo as foo from `user` as u where 1 != 1) as dt where 1 != 1",
"Query": "select dt.foo from (select u.foo as foo from `user` as u) as dt",
"Table": "`user`"
},
{
Expand Down
10 changes: 5 additions & 5 deletions go/vt/vtgate/planbuilder/testdata/tpch_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,7 @@
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0,R:0,R:1,R:2,R:3,R:4,L:3,R:5",
"JoinColumnIndexes": "L:0,R:0,R:1,R:2,R:3,R:4,L:2,R:5",
"JoinVars": {
"o_orderkey": 1
},
Expand All @@ -1145,8 +1145,8 @@
"Name": "main",
"Sharded": true
},
"FieldQuery": "select profit.o_year, profit.o_orderkey, weight_string(profit.o_year), weight_string(extract(year from o_orderdate)) from (select extract(year from o_orderdate) as o_year, o_orderkey as o_orderkey from orders where 1 != 1) as profit where 1 != 1",
"Query": "select profit.o_year, profit.o_orderkey, weight_string(profit.o_year), weight_string(extract(year from o_orderdate)) from (select extract(year from o_orderdate) as o_year, o_orderkey as o_orderkey from orders) as profit",
"FieldQuery": "select profit.o_year, profit.o_orderkey, weight_string(profit.o_year) from (select extract(year from o_orderdate) as o_year, o_orderkey as o_orderkey from orders where 1 != 1) as profit where 1 != 1",
"Query": "select profit.o_year, profit.o_orderkey, weight_string(profit.o_year) from (select extract(year from o_orderdate) as o_year, o_orderkey as o_orderkey from orders) as profit",
"Table": "orders"
},
{
Expand Down Expand Up @@ -1192,8 +1192,8 @@
"Name": "main",
"Sharded": true
},
"FieldQuery": "select profit.l_extendedprice, profit.l_discount, profit.l_quantity, profit.l_suppkey, profit.l_partkey, weight_string(profit.l_suppkey) from (select l_extendedprice, l_discount, l_quantity, l_suppkey as l_suppkey, l_partkey as l_partkey from lineitem where 1 != 1) as profit where 1 != 1",
"Query": "select profit.l_extendedprice, profit.l_discount, profit.l_quantity, profit.l_suppkey, profit.l_partkey, weight_string(profit.l_suppkey) from (select l_extendedprice, l_discount, l_quantity, l_suppkey as l_suppkey, l_partkey as l_partkey from lineitem where l_orderkey = :o_orderkey) as profit",
"FieldQuery": "select profit.l_extendedprice, profit.l_discount, profit.l_quantity, profit.l_suppkey, profit.l_partkey, weight_string(profit.l_suppkey) from (select l_extendedprice as l_extendedprice, l_discount as l_discount, l_quantity as l_quantity, l_suppkey as l_suppkey, l_partkey as l_partkey from lineitem where 1 != 1) as profit where 1 != 1",
"Query": "select profit.l_extendedprice, profit.l_discount, profit.l_quantity, profit.l_suppkey, profit.l_partkey, weight_string(profit.l_suppkey) from (select l_extendedprice as l_extendedprice, l_discount as l_discount, l_quantity as l_quantity, l_suppkey as l_suppkey, l_partkey as l_partkey from lineitem where l_orderkey = :o_orderkey) as profit",
"Table": "lineitem"
}
]
Expand Down

0 comments on commit 703dfa4

Please sign in to comment.