Skip to content

Commit

Permalink
feat: handle last_insert_id with arguments even when deeply nested in…
Browse files Browse the repository at this point in the history
… SELECT expressions

Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay committed Dec 10, 2024
1 parent f3ee39d commit 6460bfc
Show file tree
Hide file tree
Showing 5 changed files with 309 additions and 238 deletions.
26 changes: 26 additions & 0 deletions go/vt/sqlparser/cow.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,32 @@ func CopyOnRewrite(
return out
}

func CopyAndReplaceExpr(node SQLNode, replaceFn func(node Expr) (Expr, bool)) SQLNode {
var replace Expr
pre := func(node, _ SQLNode) bool {
expr, ok := node.(Expr)
if !ok {
return true
}
newExpr, ok := replaceFn(expr)
if !ok {
return true
}
replace = newExpr
return false
}

post := func(cursor *CopyOnWriteCursor) {
if replace == nil {
return
}
cursor.Replace(replace)
replace = nil
}

return CopyOnRewrite(node, pre, post, nil)
}

// StopTreeWalk aborts the current tree walking. No more nodes will be visited, and the rewriter will exit out early
func (c *CopyOnWriteCursor) StopTreeWalk() {
c.stop = true
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/executorcontext/vcursor_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -1130,8 +1130,8 @@ func (vc *VCursorImpl) SetFoundRows(foundRows uint64) {
vc.SafeSession.SetFoundRows(foundRows)
}

func (vc *vcursorImpl) SetLastInsertID(id uint64) {
vc.safeSession.LastInsertId = id
func (vc *VCursorImpl) SetLastInsertID(id uint64) {
vc.SafeSession.LastInsertId = id
}

// SetDDLStrategy implements the SessionActions interface
Expand Down
93 changes: 68 additions & 25 deletions go/vt/vtgate/planbuilder/operators/query_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ import (
func planQuery(ctx *plancontext.PlanningContext, root Operator) Operator {
var selExpr sqlparser.SelectExprs
if horizon, isHorizon := root.(*Horizon); isHorizon {
sel := sqlparser.GetFirstSelect(horizon.Query)
selExpr = sqlparser.Clone(sel.SelectExprs)
selExpr = extractSelectExpressions(horizon)
}

output := runPhases(ctx, root)
Expand Down Expand Up @@ -821,36 +820,46 @@ func tryPushUnion(ctx *plancontext.PlanningContext, op *Union) (Operator, *Apply
return newUnion(sources, selects, op.unionColumns, op.distinct), Rewrote("merge union inputs")
}

// addTruncationOrProjectionToReturnOutput uses the original Horizon to make sure that the output columns line up with what the user asked for
func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, selExprs sqlparser.SelectExprs, output Operator) Operator {
if len(selExprs) == 0 {
return output
}

cols := output.GetSelectExprs(ctx)
sizeCorrect := len(selExprs) == len(cols) || tryTruncateColumnsAt(output, len(selExprs))
if !sizeCorrect || !colNamesAlign(selExprs, cols) {
output = createSimpleProjection(ctx, selExprs, output)
}

if !ctx.SemTable.QuerySignature.LastInsertIDArg {
return output
}

var offset int
for i, expr := range selExprs {
func handleLastInsertIDColumns(ctx *plancontext.PlanningContext, output Operator) Operator {
offset := -1
topLevel := false
var arg sqlparser.Expr
for i, expr := range output.GetSelectExprs(ctx) {
ae, ok := expr.(*sqlparser.AliasedExpr)
if !ok {
panic(vterrors.VT09015())
}
fnc, ok := ae.Expr.(*sqlparser.FuncExpr)
if !ok || !fnc.Name.EqualString("last_insert_id") {
continue

replaceFn := func(node sqlparser.Expr) (sqlparser.Expr, bool) {
fnc, ok := node.(*sqlparser.FuncExpr)
if !ok || !fnc.Name.EqualString("last_insert_id") {
return node, false
}
if offset != -1 {
panic(vterrors.VT12001("last_insert_id() found multiple times in select list"))
}
arg = fnc.Exprs[0]
if node == ae.Expr {
topLevel = true
}
offset = i
return arg, true
}

newExpr := sqlparser.CopyAndReplaceExpr(ae.Expr, replaceFn)
ae.Expr = newExpr.(sqlparser.Expr)
}

if topLevel {
return &SaveToSession{
unaryOperator: unaryOperator{
Source: output,
},
Offset: offset,
}
offset = i
break
}

offset = output.AddColumn(ctx, false, false, aeWrap(arg))
return &SaveToSession{
unaryOperator: unaryOperator{
Source: output,
Expand All @@ -859,6 +868,40 @@ func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, s
}
}

// addTruncationOrProjectionToReturnOutput uses the original Horizon to make sure that the output columns line up with what the user asked for
func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, selExprs sqlparser.SelectExprs, output Operator) Operator {
if len(selExprs) == 0 {
return output
}

if ctx.SemTable.QuerySignature.LastInsertIDArg {
output = handleLastInsertIDColumns(ctx, output)
}

cols := output.GetSelectExprs(ctx)
sizeCorrect := len(selExprs) == len(cols) || tryTruncateColumnsAt(output, len(selExprs))
if sizeCorrect && colNamesAlign(selExprs, cols) {
return output
}

return createSimpleProjection(ctx, selExprs, output)
}

func extractSelectExpressions(horizon *Horizon) sqlparser.SelectExprs {
sel := sqlparser.GetFirstSelect(horizon.Query)
// we handle last_insert_id with arguments separately - no need to send this down to mysql
selExprs := sqlparser.CopyAndReplaceExpr(sel.SelectExprs, func(node sqlparser.Expr) (sqlparser.Expr, bool) {
fnc, ok := node.(*sqlparser.FuncExpr)
if !ok || !fnc.Name.EqualString("last_insert_id") || len(fnc.Exprs) != 1 {
return nil, false
}

return fnc.Exprs[0], true
})

return selExprs.(sqlparser.SelectExprs)
}

func colNamesAlign(expected, actual sqlparser.SelectExprs) bool {
if len(expected) > len(actual) {
// if we expect more columns than we have, we can't align
Expand Down
3 changes: 2 additions & 1 deletion go/vt/vtgate/planbuilder/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,9 @@ func (s *planTestSuite) testFile(filename string, vschema *vschemawrapper.VSchem
continue
}
current := PlanTest{
Comment: testName,
Comment: tcase.Comment,
Query: tcase.Query,
SkipE2E: tcase.SkipE2E,
}
vschema.Version = Gen4
out := getPlanOutput(tcase, vschema, render)
Expand Down
Loading

0 comments on commit 6460bfc

Please sign in to comment.